From 32da24366e5dbb6a7c271540cf4fe5198ba2837a Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 23 Dec 2024 20:30:40 +0800 Subject: [PATCH 01/82] 4424 --- deepmd/pt/utils/dataset.py | 12 ++++++++++++ deepmd/pt/utils/stat.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 3043839308..df4f1fc6cd 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -40,6 +40,18 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data + def _build_element_to_frames(self): + """Mapping element types to frame indexes""" + element_to_frames = {element: [] for element in range(self._ntypes)} + for frame_idx in range(len(self)): + frame_data = self._data_system.get_item_torch(frame_idx) + + elements = frame_data["atype"] + for element in set(elements): + if len(element_to_frames[element]) < 10: + element_to_frames[element].append(frame_idx) + return element_to_frames + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" for data_item in data_requirement: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1c5e3f1c52..82cb816f7b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -82,6 +82,41 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) + + all_elements = set() + if datasets and hasattr(datasets[0], 'element_to_frames'): + all_elements.update(datasets[0].element_to_frames.keys()) + print('we want', all_elements) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + missing_elements = all_elements - collected_elements + + for missing_element in missing_elements: + for i, dataset in enumerate(datasets): + if hasattr(dataset, 
'element_to_frames'): + frame_indices = dataset.element_to_frames.get(missing_element, []) + for frame_idx in frame_indices: + if len(lst[i]['atype']) >= nbatches: + break + frame_data = dataset[frame_idx] + for key in frame_data: + if key not in lst[i]: + lst[i][key] = [] + lst[i][key].append(frame_data[key]) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + + for sys_stat in lst: + for key in sys_stat: + if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + return lst From adf2315d57730d2cdfe2a4244a2d2e167bcb45e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 12:35:16 +0000 Subject: [PATCH 02/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 6 +++--- deepmd/pt/utils/stat.py | 26 ++++++++++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index df4f1fc6cd..b5fd8c58a0 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -42,13 +42,13 @@ def __getitem__(self, index): def _build_element_to_frames(self): """Mapping element types to frame indexes""" - element_to_frames = {element: [] for element in range(self._ntypes)} + element_to_frames = {element: [] for element in range(self._ntypes)} for frame_idx in range(len(self)): frame_data = self._data_system.get_item_torch(frame_idx) - elements = frame_data["atype"] + elements = frame_data["atype"] for element in set(elements): - if len(element_to_frames[element]) < 10: + if len(element_to_frames[element]) < 10: element_to_frames[element].append(frame_idx) return element_to_frames diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 
82cb816f7b..e3e7410a21 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -84,22 +84,24 @@ def make_stat_input(datasets, dataloaders, nbatches): lst.append(sys_stat) all_elements = set() - if datasets and hasattr(datasets[0], 'element_to_frames'): + if datasets and hasattr(datasets[0], "element_to_frames"): all_elements.update(datasets[0].element_to_frames.keys()) - print('we want', all_elements) + print("we want", all_elements) collected_elements = set() for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + if "atype" in sys_stat: + collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy())) missing_elements = all_elements - collected_elements for missing_element in missing_elements: for i, dataset in enumerate(datasets): - if hasattr(dataset, 'element_to_frames'): - frame_indices = dataset.element_to_frames.get(missing_element, []) + if hasattr(dataset, "element_to_frames"): + frame_indices = dataset.element_to_frames.get( + missing_element, [] + ) for frame_idx in frame_indices: - if len(lst[i]['atype']) >= nbatches: + if len(lst[i]["atype"]) >= nbatches: break frame_data = dataset[frame_idx] for key in frame_data: @@ -109,12 +111,16 @@ def make_stat_input(datasets, dataloaders, nbatches): collected_elements = set() for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + if "atype" in sys_stat: + collected_elements.update( + np.unique(sys_stat["atype"].cpu().numpy()) + ) for sys_stat in lst: for key in sys_stat: - if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + if isinstance(sys_stat[key], list) and isinstance( + sys_stat[key][0], torch.Tensor + ): sys_stat[key] = torch.cat(sys_stat[key], dim=0) return lst From 4f6f63d13af6b187766c3f0c46dc7eeb8f3b995f Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 26 Dec 2024 14:37:16 +0800 Subject: [PATCH 03/82] issues4424-2 --- 
deepmd/pt/utils/dataset.py | 27 +++++---- deepmd/pt/utils/stat.py | 81 ++++++++++++++----------- source/tests/pt/test_make_stat_input.py | 50 +++++++++++++++ 3 files changed, 110 insertions(+), 48 deletions(-) create mode 100644 source/tests/pt/test_make_stat_input.py diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index df4f1fc6cd..e7ed73946e 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -30,6 +30,7 @@ def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: self._ntypes = self._data_system.get_ntypes() self._natoms = self._data_system.get_natoms() self._natoms_vec = self._data_system.get_natoms_vec(self._ntypes) + self.element_to_frames, self.get_all_atype = self._build_element_to_frames() def __len__(self) -> int: return self._data_system.nframes @@ -40,17 +41,21 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data - def _build_element_to_frames(self): - """Mapping element types to frame indexes""" - element_to_frames = {element: [] for element in range(self._ntypes)} - for frame_idx in range(len(self)): - frame_data = self._data_system.get_item_torch(frame_idx) - - elements = frame_data["atype"] - for element in set(elements): - if len(element_to_frames[element]) < 10: - element_to_frames[element].append(frame_idx) - return element_to_frames + def _build_element_to_frames(self): + """Build mapping from element types to frame indexes and return all unique element types.""" + element_to_frames = {element: [] for element in range(self._ntypes)} + all_elements = set() + all_frame_data = self._data_system.get_batch(self._data_system.nframes) + all_elements = np.unique(all_frame_data["type"]) + for i in range(len(self)): + for element in all_elements: + element_to_frames[element].append(i) + return element_to_frames, all_elements + + def get_frames_for_element(self, missing_element_name): + """Get the frames with the element type.""" + element_index = 
self._type_map.index(missing_element_name) + return self.element_to_frames.get(element_index, []) def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 82cb816f7b..4bd15f9e3b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,9 +36,8 @@ log = logging.getLogger(__name__) - def make_stat_input(datasets, dataloaders, nbatches): - """Pack data for statistics. + """Pack data for statistics with all elements. Args: - dataset: A list of dataset to analyze. @@ -82,44 +81,52 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) - - all_elements = set() - if datasets and hasattr(datasets[0], 'element_to_frames'): - all_elements.update(datasets[0].element_to_frames.keys()) - print('we want', all_elements) - - collected_elements = set() - for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) - missing_elements = all_elements - collected_elements - - for missing_element in missing_elements: - for i, dataset in enumerate(datasets): - if hasattr(dataset, 'element_to_frames'): - frame_indices = dataset.element_to_frames.get(missing_element, []) - for frame_idx in frame_indices: - if len(lst[i]['atype']) >= nbatches: - break - frame_data = dataset[frame_idx] - for key in frame_data: - if key not in lst[i]: - lst[i][key] = [] - lst[i][key].append(frame_data[key]) - - collected_elements = set() - for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) - - for sys_stat in lst: - for key in sys_stat: - if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): - sys_stat[key] = torch.cat(sys_stat[key], dim=0) + unique_elements = set() + all_element = set() + + for i in 
lst: + unique_values = np.unique(i['atype'].cpu().numpy()) + unique_elements.update(unique_values) + for i in datasets: + all_elements_in_dataset = i.get_all_atype + all_element.update(all_elements_in_dataset) + missing_element = all_element - unique_elements + for miss in missing_element: + for i in datasets: + if i.element_to_frames.get(miss, []) is not None: + frame_indices = i.element_to_frames.get(miss, []) + frame_data = i.__getitem__(frame_indices[0]) + break + else: + pass + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + frame_data[dd] = torch.from_numpy(frame_data[dd]) + frame_data[dd] = frame_data[dd].unsqueeze(0) + sys_stat_new[dd].append(frame_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] + else: + pass + for key in sys_stat_new: + if isinstance(sys_stat_new[key], np.float32): + pass + elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + sys_stat_new[key] = None + elif isinstance(stat_data[dd], torch.Tensor): + sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) + dict_to_device(sys_stat_new) + lst.append(sys_stat_new) return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py new file mode 100644 index 0000000000..fe2a4c9e51 --- /dev/null +++ b/source/tests/pt/test_make_stat_input.py @@ -0,0 +1,50 @@ +import unittest +import torch +from torch.utils.data import DataLoader +from deepmd.pt.utils.stat import make_stat_input + +class TestDataset: + def __init__(self, samples): + + self.samples = samples + self.element_to_frames = {} + for idx, sample in enumerate(samples): + atypes = sample['atype'] + for atype in atypes: + if atype not in self.element_to_frames: + 
self.element_to_frames[atype] = [] + self.element_to_frames[atype].append(idx) + + @property + def get_all_atype(self): + return set(self.element_to_frames.keys()) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + sample = self.samples[idx] + return { + 'atype': torch.tensor(sample['atype'], dtype=torch.long), + 'energy': torch.tensor(sample['energy'], dtype=torch.float32), + } + +class TestMakeStatInput(unittest.TestCase): + def setUp(self): + self.system = TestDataset([ + {'atype': [1], 'energy': -1.0}, + {'atype': [2], 'energy': -2.0}, + ]) + self.datasets = [self.system] + self.dataloaders = [ + DataLoader(self.system, batch_size=1, shuffle=False), + ] + def test_make_stat_input(self): + nbatches = 1 + lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches) + all_elements = self.system.get_all_atype + unique_elements = {1,2} + self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements") + +if __name__ == '__main__': + unittest.main() From b9bac38a3159322310eb28ef66cd7c4050452ec6 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 26 Dec 2024 14:51:13 +0800 Subject: [PATCH 04/82] ll --- deepmd/pt/utils/dataset.py | 26 +++++------- deepmd/pt/utils/stat.py | 83 ++++++++++++++++++-------------------- 2 files changed, 51 insertions(+), 58 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index e7ed73946e..d979224371 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -41,21 +41,17 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data - def _build_element_to_frames(self): - """Build mapping from element types to frame indexes and return all unique element types.""" - element_to_frames = {element: [] for element in range(self._ntypes)} - all_elements = set() - all_frame_data = self._data_system.get_batch(self._data_system.nframes) - all_elements = np.unique(all_frame_data["type"]) - for i in range(len(self)): - for 
element in all_elements: - element_to_frames[element].append(i) - return element_to_frames, all_elements - - def get_frames_for_element(self, missing_element_name): - """Get the frames with the element type.""" - element_index = self._type_map.index(missing_element_name) - return self.element_to_frames.get(element_index, []) + def _build_element_to_frames(self): + """Mapping element types to frame indexes""" + element_to_frames = {element: [] for element in range(self._ntypes)} + for frame_idx in range(len(self)): + frame_data = self._data_system.get_item_torch(frame_idx) + + elements = frame_data["atype"] + for element in set(elements): + if len(element_to_frames[element]) < 10: + element_to_frames[element].append(frame_idx) + return element_to_frames def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 4bd15f9e3b..1bf7811a5c 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -81,49 +81,46 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) - unique_elements = set() - all_element = set() - - for i in lst: - unique_values = np.unique(i['atype'].cpu().numpy()) - unique_elements.update(unique_values) - for i in datasets: - all_elements_in_dataset = i.get_all_atype - all_element.update(all_elements_in_dataset) - missing_element = all_element - unique_elements - for miss in missing_element: - for i in datasets: - if i.element_to_frames.get(miss, []) is not None: - frame_indices = i.element_to_frames.get(miss, []) - frame_data = i.__getitem__(frame_indices[0]) - break - else: - pass - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if frame_data[dd] is None: - sys_stat_new[dd] = None - elif isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] 
- frame_data[dd] = torch.from_numpy(frame_data[dd]) - frame_data[dd] = frame_data[dd].unsqueeze(0) - sys_stat_new[dd].append(frame_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: - sys_stat_new[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) - lst.append(sys_stat_new) + + all_elements = set() + if datasets and hasattr(datasets[0], "element_to_frames"): + all_elements.update(datasets[0].element_to_frames.keys()) + print("we want", all_elements) + + collected_elements = set() + for sys_stat in lst: + if "atype" in sys_stat: + collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy())) + missing_elements = all_elements - collected_elements + + for missing_element in missing_elements: + for i, dataset in enumerate(datasets): + if hasattr(dataset, "element_to_frames"): + frame_indices = dataset.element_to_frames.get( + missing_element, [] + ) + for frame_idx in frame_indices: + if len(lst[i]["atype"]) >= nbatches: + break + frame_data = dataset[frame_idx] + for key in frame_data: + if key not in lst[i]: + lst[i][key] = [] + lst[i][key].append(frame_data[key]) + + collected_elements = set() + for sys_stat in lst: + if "atype" in sys_stat: + collected_elements.update( + np.unique(sys_stat["atype"].cpu().numpy()) + ) + + for sys_stat in lst: + for key in sys_stat: + if isinstance(sys_stat[key], list) and isinstance( + sys_stat[key][0], torch.Tensor + ): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) return lst From 543a3181942a1fbadb6a59b0b48549dd31991d0e Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 26 Dec 2024 14:54:59 +0800 Subject: [PATCH 05/82] ll --- deepmd/pt/utils/dataset.py | 34 ++++++++------- deepmd/pt/utils/stat.py | 86 
++++++++++++++++++++------------------ 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index d979224371..cd0e9074ff 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -13,14 +13,14 @@ DataRequirementItem, DeepmdData, ) - +import numpy as np class DeepmdDataSetForLoader(Dataset): def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: - """Construct DeePMD-style dataset containing frames cross different systems. + """Construct DeePMD-style dataset containing frames across different systems. Args: - - systems: Paths to systems. + - system: Path to the system. - type_map: Atom types. """ self.system = system @@ -40,18 +40,22 @@ def __getitem__(self, index): b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data - - def _build_element_to_frames(self): - """Mapping element types to frame indexes""" - element_to_frames = {element: [] for element in range(self._ntypes)} - for frame_idx in range(len(self)): - frame_data = self._data_system.get_item_torch(frame_idx) - - elements = frame_data["atype"] - for element in set(elements): - if len(element_to_frames[element]) < 10: - element_to_frames[element].append(frame_idx) - return element_to_frames + + def _build_element_to_frames(self): + """Build mapping from element types to frame indexes and return all unique element types.""" + element_to_frames = {element: [] for element in range(self._ntypes)} + all_elements = set() + all_frame_data = self._data_system.get_batch(self._data_system.nframes) + all_elements = np.unique(all_frame_data["type"]) + for i in range(len(self)): + for element in all_elements: + element_to_frames[element].append(i) + return element_to_frames, all_elements + + def get_frames_for_element(self, missing_element_name): + """Get the frames that contain the specified element type.""" + element_index = self._type_map.index(missing_element_name) 
+ return self.element_to_frames.get(element_index, []) def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1bf7811a5c..099ded21f4 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -37,7 +37,7 @@ log = logging.getLogger(__name__) def make_stat_input(datasets, dataloaders, nbatches): - """Pack data for statistics with all elements. + """Pack data for statistics. Args: - dataset: A list of dataset to analyze. @@ -81,46 +81,50 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) - - all_elements = set() - if datasets and hasattr(datasets[0], "element_to_frames"): - all_elements.update(datasets[0].element_to_frames.keys()) - print("we want", all_elements) - - collected_elements = set() - for sys_stat in lst: - if "atype" in sys_stat: - collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy())) - missing_elements = all_elements - collected_elements - - for missing_element in missing_elements: - for i, dataset in enumerate(datasets): - if hasattr(dataset, "element_to_frames"): - frame_indices = dataset.element_to_frames.get( - missing_element, [] - ) - for frame_idx in frame_indices: - if len(lst[i]["atype"]) >= nbatches: - break - frame_data = dataset[frame_idx] - for key in frame_data: - if key not in lst[i]: - lst[i][key] = [] - lst[i][key].append(frame_data[key]) - - collected_elements = set() - for sys_stat in lst: - if "atype" in sys_stat: - collected_elements.update( - np.unique(sys_stat["atype"].cpu().numpy()) - ) - - for sys_stat in lst: - for key in sys_stat: - if isinstance(sys_stat[key], list) and isinstance( - sys_stat[key][0], torch.Tensor - ): - sys_stat[key] = torch.cat(sys_stat[key], dim=0) + unique_elements = set() + all_element = set() + + for i in lst: + unique_values = 
np.unique(i['atype'].cpu().numpy()) + unique_elements.update(unique_values) + for i in datasets: + all_elements_in_dataset = i.get_all_atype + all_element.update(all_elements_in_dataset) + print(all_element) + missing_element = all_element - unique_elements + for miss in missing_element: + for i in datasets: + if i.element_to_frames.get(miss, []) is not None: + frame_indices = i.element_to_frames.get(miss, []) + frame_data = i.__getitem__(frame_indices[0]) + break + else: + pass + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + frame_data[dd] = torch.from_numpy(frame_data[dd]) + frame_data[dd] = frame_data[dd].unsqueeze(0) + sys_stat_new[dd].append(frame_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] + else: + pass + for key in sys_stat_new: + if isinstance(sys_stat_new[key], np.float32): + pass + elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + sys_stat_new[key] = None + elif isinstance(stat_data[dd], torch.Tensor): + sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) + dict_to_device(sys_stat_new) + lst.append(sys_stat_new) return lst From ba723821996a81beaf8b68bcb29991d30d71c855 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Dec 2024 06:56:29 +0000 Subject: [PATCH 06/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 15 +++++----- deepmd/pt/utils/stat.py | 8 ++++-- source/tests/pt/test_make_stat_input.py | 38 ++++++++++++++++--------- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index cd0e9074ff..045145a4fc 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -5,6 
+5,7 @@ Optional, ) +import numpy as np from torch.utils.data import ( Dataset, ) @@ -13,7 +14,7 @@ DataRequirementItem, DeepmdData, ) -import numpy as np + class DeepmdDataSetForLoader(Dataset): def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: @@ -40,21 +41,21 @@ def __getitem__(self, index): b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data - + def _build_element_to_frames(self): """Build mapping from element types to frame indexes and return all unique element types.""" - element_to_frames = {element: [] for element in range(self._ntypes)} - all_elements = set() + element_to_frames = {element: [] for element in range(self._ntypes)} + all_elements = set() all_frame_data = self._data_system.get_batch(self._data_system.nframes) all_elements = np.unique(all_frame_data["type"]) - for i in range(len(self)): + for i in range(len(self)): for element in all_elements: element_to_frames[element].append(i) return element_to_frames, all_elements - + def get_frames_for_element(self, missing_element_name): """Get the frames that contain the specified element type.""" - element_index = self._type_map.index(missing_element_name) + element_index = self._type_map.index(missing_element_name) return self.element_to_frames.get(element_index, []) def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 099ded21f4..6013603913 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,6 +36,7 @@ log = logging.getLogger(__name__) + def make_stat_input(datasets, dataloaders, nbatches): """Pack data for statistics. 
@@ -85,7 +86,7 @@ def make_stat_input(datasets, dataloaders, nbatches): all_element = set() for i in lst: - unique_values = np.unique(i['atype'].cpu().numpy()) + unique_values = np.unique(i["atype"].cpu().numpy()) unique_elements.update(unique_values) for i in datasets: all_elements_in_dataset = i.get_all_atype @@ -102,7 +103,7 @@ def make_stat_input(datasets, dataloaders, nbatches): pass sys_stat_new = {} for dd in frame_data: - if dd == "type": + if dd == "type": continue if frame_data[dd] is None: sys_stat_new[dd] = None @@ -123,11 +124,12 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat_new[key] = None elif isinstance(stat_data[dd], torch.Tensor): sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) + dict_to_device(sys_stat_new) lst.append(sys_stat_new) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index fe2a4c9e51..2cd67193c2 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,15 +1,22 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later import unittest + import torch -from torch.utils.data import DataLoader -from deepmd.pt.utils.stat import make_stat_input +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.stat import ( + make_stat_input, +) + class TestDataset: def __init__(self, samples): - self.samples = samples self.element_to_frames = {} for idx, sample in enumerate(samples): - atypes = sample['atype'] + atypes = sample["atype"] for atype in atypes: if atype not in self.element_to_frames: self.element_to_frames[atype] = [] @@ -25,26 +32,31 @@ def __len__(self): def __getitem__(self, idx): sample = self.samples[idx] return { - 'atype': torch.tensor(sample['atype'], dtype=torch.long), - 'energy': torch.tensor(sample['energy'], dtype=torch.float32), + "atype": torch.tensor(sample["atype"], 
dtype=torch.long), + "energy": torch.tensor(sample["energy"], dtype=torch.float32), } + class TestMakeStatInput(unittest.TestCase): def setUp(self): - self.system = TestDataset([ - {'atype': [1], 'energy': -1.0}, - {'atype': [2], 'energy': -2.0}, - ]) + self.system = TestDataset( + [ + {"atype": [1], "energy": -1.0}, + {"atype": [2], "energy": -2.0}, + ] + ) self.datasets = [self.system] self.dataloaders = [ DataLoader(self.system, batch_size=1, shuffle=False), ] + def test_make_stat_input(self): - nbatches = 1 + nbatches = 1 lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches) all_elements = self.system.get_all_atype - unique_elements = {1,2} + unique_elements = {1, 2} self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() From 26f9a171d211347c93d42baafe0dd276cd389f1e Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 26 Dec 2024 15:05:27 +0800 Subject: [PATCH 07/82] lll --- deepmd/pt/utils/stat.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 099ded21f4..f624e3d0bd 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -81,17 +81,16 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) - unique_elements = set() - all_element = set() + collect_elements = set() + all_element = set() for i in lst: - unique_values = np.unique(i['atype'].cpu().numpy()) - unique_elements.update(unique_values) + collect_values = np.unique(i['atype'].cpu().numpy()) + collect_elements.update(collect_values) for i in datasets: all_elements_in_dataset = i.get_all_atype all_element.update(all_elements_in_dataset) - print(all_element) - missing_element = all_element - unique_elements + missing_element = all_element - collect_elements for miss in missing_element: for i in 
datasets: if i.element_to_frames.get(miss, []) is not None: From dc6430730dd18087d0b6fafd98c5e4add795a57a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Dec 2024 07:07:55 +0000 Subject: [PATCH 08/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index ff40f6cc3d..7799b7c666 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -86,7 +86,7 @@ def make_stat_input(datasets, dataloaders, nbatches): collect_elements = set() all_element = set() for i in lst: - collect_values = np.unique(i['atype'].cpu().numpy()) + collect_values = np.unique(i["atype"].cpu().numpy()) collect_elements.update(collect_values) for i in datasets: all_elements_in_dataset = i.get_all_atype From 8f962b57957cd0e84ccdb4f4251b341b80624710 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 2 Jan 2025 10:37:20 +0800 Subject: [PATCH 09/82] allchange --- deepmd/pt/train/training.py | 2 + deepmd/pt/utils/dataset.py | 40 +++++--- deepmd/pt/utils/stat.py | 198 ++++++++++++++++++++++-------------- deepmd/utils/argcheck.py | 6 ++ deepmd/utils/data.py | 16 +-- 5 files changed, 162 insertions(+), 100 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f74c4769bf..cfba070277 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -142,6 +142,7 @@ def __init__( self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) self.display_in_training = training_params.get("disp_training", True) self.timing_in_training = training_params.get("time_training", True) + self.min_frames_per_element_forstat = training_params.get("min_frames_per_element_forstat", 10) self.change_bias_after_training = training_params.get( "change_bias_after_training", False ) @@ -226,6 +227,7 @@ def 
get_sample(): _training_data.systems, _training_data.dataloaders, _data_stat_nbatch, + self.min_frames_per_element_forstat, ) return sampled diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 045145a4fc..fb63388973 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -6,6 +6,9 @@ ) import numpy as np +import glob +import os +from collections import defaultdict from torch.utils.data import ( Dataset, ) @@ -15,6 +18,10 @@ DeepmdData, ) +from deepmd.utils.path import ( + DPPath, +) + class DeepmdDataSetForLoader(Dataset): def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: @@ -31,7 +38,6 @@ def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: self._ntypes = self._data_system.get_ntypes() self._natoms = self._data_system.get_natoms() self._natoms_vec = self._data_system.get_natoms_vec(self._ntypes) - self.element_to_frames, self.get_all_atype = self._build_element_to_frames() def __len__(self) -> int: return self._data_system.nframes @@ -42,21 +48,23 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data - def _build_element_to_frames(self): - """Build mapping from element types to frame indexes and return all unique element types.""" - element_to_frames = {element: [] for element in range(self._ntypes)} - all_elements = set() - all_frame_data = self._data_system.get_batch(self._data_system.nframes) - all_elements = np.unique(all_frame_data["type"]) - for i in range(len(self)): - for element in all_elements: - element_to_frames[element].append(i) - return element_to_frames, all_elements - - def get_frames_for_element(self, missing_element_name): - """Get the frames that contain the specified element type.""" - element_index = self._type_map.index(missing_element_name) - return self.element_to_frames.get(element_index, []) + def true_types(self): + """Identify and count unique element types present in the dataset, + and count the number of 
frames each element appears in.""" + element_counts = defaultdict(lambda: {"count": 0, "frames": 0}) + set_pattern = os.path.join(self.system, "set.*") + set_files = sorted(glob.glob(set_pattern)) + for set_file in set_files: + element_data = self._data_system._load_type_mix(DPPath(set_file)) + unique_elements, counts = np.unique(element_data, return_counts=True) + for elem, cnt in zip(unique_elements, counts): + element_counts[elem]["count"] += cnt + for elem in unique_elements: + frames_with_elem = np.any(element_data == elem, axis=1) + row_count = np.sum(frames_with_elem) + element_counts[elem]["frames"] += row_count + element_counts = dict(element_counts) + return element_counts def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index ff40f6cc3d..eb3d54b0e7 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,8 +36,7 @@ log = logging.getLogger(__name__) - -def make_stat_input(datasets, dataloaders, nbatches): +def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat): """Pack data for statistics. 
Args: @@ -50,82 +49,129 @@ def make_stat_input(datasets, dataloaders, nbatches): """ lst = [] log.info(f"Packing data for statistics from {len(datasets)} systems") - for i in range(len(datasets)): - sys_stat = {} - with torch.device("cpu"): - iterator = iter(dataloaders[i]) - numb_batches = min(nbatches, len(dataloaders[i])) - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloaders[i]) - stat_data = next(iterator) - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] + collect_elements = set() + total_element_types = set() + total_element_counts = {} + if datasets[0].mixed_type: + for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)): + sys_stat = {} + with torch.device("cpu"): + iterator = iter(dataloader) + numb_batches = min(nbatches, len(dataloader)) + for _ in range(numb_batches): + try: + stat_data = next(iterator) + except StopIteration: + iterator = iter(dataloader) + stat_data = next(iterator) + for dd in stat_data: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(stat_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat[dd] = stat_data[dd] + else: + pass + if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): + collect_values = np.unique(torch.cat(sys_stat['atype']).numpy()) + collect_elements.update(collect_values) + + for key in sys_stat: + if isinstance(sys_stat[key], np.float32): + pass + elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + sys_stat[key] = None + elif isinstance(sys_stat[key][0], torch.Tensor): + sys_stat[key] = 
torch.cat(sys_stat[key], dim=0) + dict_to_device(sys_stat) + lst.append(sys_stat) + + element_counts = dataset.true_types() + for elem, data in element_counts.items(): + count = data["count"] + frames = data["frames"] + total_element_types.add(elem) + if elem not in total_element_counts: + total_element_counts[elem] = {"count": 0, "frames": 0, "indices": []} + total_element_counts[elem]["count"] += count + if len(total_element_counts[elem]["indices"]) < min_frames_per_element_forstat: + total_element_counts[elem]["indices"].append({ + "sys_index": sys_index, + "frames": frames + }) + for elem, data in total_element_counts.items(): + count = data["count"] + indices_count = len(data["indices"]) + if indices_count < min_frames_per_element_forstat: + log.warning(f'The number of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}') + missing_elements = total_element_types - collect_elements + for miss in missing_elements: + sys_indices = total_element_counts[miss].get('indices', []) + for sys_info in sys_indices: + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] + sys = datasets[sys_index] + frame_data = sys.__getitem__(frames) + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + frame_data[dd] = torch.from_numpy(frame_data[dd]) + frame_data[dd] = frame_data[dd].unsqueeze(0) + sys_stat_new[dd].append(frame_data[dd]) + elif isinstance(frame_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] else: pass - - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): - pass - elif sys_stat[key] is None or sys_stat[key][0] is None: - sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - sys_stat[key] = torch.cat(sys_stat[key], dim=0) - dict_to_device(sys_stat) - lst.append(sys_stat) 
- - collect_elements = set() - all_element = set() - for i in lst: - collect_values = np.unique(i['atype'].cpu().numpy()) - collect_elements.update(collect_values) - for i in datasets: - all_elements_in_dataset = i.get_all_atype - all_element.update(all_elements_in_dataset) - missing_element = all_element - collect_elements - for miss in missing_element: - for i in datasets: - if i.element_to_frames.get(miss, []) is not None: - frame_indices = i.element_to_frames.get(miss, []) - frame_data = i.__getitem__(frame_indices[0]) - break - else: - pass - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if frame_data[dd] is None: - sys_stat_new[dd] = None - elif isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] - frame_data[dd] = torch.from_numpy(frame_data[dd]) - frame_data[dd] = frame_data[dd].unsqueeze(0) - sys_stat_new[dd].append(frame_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: - sys_stat_new[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) - lst.append(sys_stat_new) - + for key in sys_stat_new: + if isinstance(sys_stat_new[key], np.float32): + pass + elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + sys_stat_new[key] = None + elif isinstance(frame_data[dd], torch.Tensor): + sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) + dict_to_device(sys_stat_new) + lst.append(sys_stat_new) + else: + for i in range(len(datasets)): + sys_stat = {} + with torch.device("cpu"): + iterator = iter(dataloaders[i]) + numb_batches = min(nbatches, len(dataloaders[i])) + for _ in range(numb_batches): + try: + stat_data = next(iterator) + except StopIteration: + iterator = iter(dataloaders[i]) + 
stat_data = next(iterator) + for dd in stat_data: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(stat_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat[dd] = stat_data[dd] + else: + pass + for key in sys_stat: + if isinstance(sys_stat[key], np.float32): + pass + elif sys_stat[key] is None or sys_stat[key][0] is None: + sys_stat[key] = None + elif isinstance(stat_data[dd], torch.Tensor): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + dict_to_device(sys_stat) + lst.append(sys_stat) return lst diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d5419a38cd..5b1fb05f2e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2826,6 +2826,12 @@ def training_args( optional=True, doc=doc_only_pt_supported + doc_gradient_max_norm, ), + Argument( + "min_frames_per_element_forstat", + int, + optional=True, + doc="The minimum number of frames per element used for statistics.", + ), ] variants = [ Variant( diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 493a9d8d54..2afe20acd8 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -530,14 +530,6 @@ def _load_set(self, set_name: DPPath): if self.mixed_type: # nframes x natoms atom_type_mix = self._load_type_mix(set_name) - if self.enforce_type_map: - try: - atom_type_mix_ = self.type_idx_map[atom_type_mix].astype(np.int32) - except IndexError as e: - raise IndexError( - f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!" 
def _load_type_mix(self, set_name: DPPath):
    """Load the per-frame real atom types of one set directory.

    Reads ``real_atom_types.npy`` (shape: nframes x natoms) and, when
    ``enforce_type_map`` is enabled, remaps the raw types through
    ``type_idx_map`` so they follow the user-provided type map.

    Parameters
    ----------
    set_name : DPPath
        Path of the ``set.*`` directory to read from.

    Returns
    -------
    np.ndarray
        int32 array of shape (nframes, natoms) with the (possibly remapped)
        atom types of every frame.

    Raises
    ------
    IndexError
        If a raw type index falls outside ``type_idx_map`` while
        ``enforce_type_map`` is on.
    """
    raw = (set_name / "real_atom_types.npy").load_numpy()
    real_type = raw.astype(np.int32).reshape([-1, self.natoms])
    if self.enforce_type_map:
        try:
            remapped = self.type_idx_map[real_type].astype(np.int32)
        except IndexError as e:
            raise IndexError(
                f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
            ) from e
        real_type = remapped
    return real_type
fb63388973..042714f5af 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import glob +import os +from collections import ( + defaultdict, +) from typing import ( Optional, ) import numpy as np -import glob -import os -from collections import defaultdict from torch.utils.data import ( Dataset, ) @@ -17,7 +19,6 @@ DataRequirementItem, DeepmdData, ) - from deepmd.utils.path import ( DPPath, ) @@ -50,7 +51,8 @@ def __getitem__(self, index): def true_types(self): """Identify and count unique element types present in the dataset, - and count the number of frames each element appears in.""" + and count the number of frames each element appears in. + """ element_counts = defaultdict(lambda: {"count": 0, "frames": 0}) set_pattern = os.path.join(self.system, "set.*") set_files = sorted(glob.glob(set_pattern)) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index f27a2a106c..aafe97d4cc 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -101,7 +101,7 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors total_element_counts[elem]["indices"].append({ "sys_index": sys_index, "frames": frames - }) + }) for elem, data in total_element_counts.items(): count = data["count"] indices_count = len(data["indices"]) @@ -111,8 +111,8 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors for miss in missing_elements: sys_indices = total_element_counts[miss].get('indices', []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] sys = datasets[sys_index] frame_data = sys.__getitem__(frames) sys_stat_new = {} @@ -172,7 +172,7 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors elif isinstance(stat_data[dd], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) 
dict_to_device(sys_stat) - lst.append(sys_stat) + lst.append(sys_stat) ======= for key in sys_stat: From faeb7c5693dccad6667a98edc3fbc3cec7e4d82e Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 2 Jan 2025 10:52:40 +0800 Subject: [PATCH 11/82] test --- source/tests/pt/test_make_stat_input.py | 47 ++++++++++++++++++------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 2cd67193c2..0ab49cf243 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,25 +1,21 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest - import torch -from torch.utils.data import ( - DataLoader, -) - -from deepmd.pt.utils.stat import ( - make_stat_input, -) - +from torch.utils.data import DataLoader +from deepmd.pt.utils.stat import make_stat_input +import numpy as np +import os +import glob +from collections import defaultdict class TestDataset: def __init__(self, samples): self.samples = samples - self.element_to_frames = {} + self.element_to_frames = defaultdict(list) + self.mixed_type = True for idx, sample in enumerate(samples): atypes = sample["atype"] for atype in atypes: - if atype not in self.element_to_frames: - self.element_to_frames[atype] = [] self.element_to_frames[atype].append(idx) @property @@ -36,6 +32,17 @@ def __getitem__(self, idx): "energy": torch.tensor(sample["energy"], dtype=torch.float32), } + def true_types(self): + element_counts = defaultdict(lambda: {"count": 0, "frames": 0}) + for idx, sample in enumerate(self.samples): + atypes = sample["atype"] + unique_atypes = set(atypes) + for atype in atypes: + element_counts[atype]["count"] += 1 + for atype in unique_atypes: + element_counts[atype]["frames"] += 1 + return dict(element_counts) + class TestMakeStatInput(unittest.TestCase): def setUp(self): @@ -52,11 +59,25 @@ def setUp(self): def test_make_stat_input(self): nbatches = 1 - lst = 
make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches) + lst = make_stat_input( + self.datasets, + self.dataloaders, + nbatches=nbatches, + min_frames_per_element_forstat=1, + ) all_elements = self.system.get_all_atype unique_elements = {1, 2} self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements") + expected_true_types = { + 1: {"count": 1, "frames": 1}, + 2: {"count": 1, "frames": 1}, + } + actual_true_types = self.system.true_types() + self.assertEqual( + expected_true_types, actual_true_types, "true_types is wrong" + ) + if __name__ == "__main__": unittest.main() From 725f1dd1b3d079b6a63c8ae9b98072a70777d5df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 02:54:05 +0000 Subject: [PATCH 12/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 30 ++++++++++++++----------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 0ab49cf243..4bdfc26732 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +from collections import ( + defaultdict, +) + import torch -from torch.utils.data import DataLoader -from deepmd.pt.utils.stat import make_stat_input -import numpy as np -import os -import glob -from collections import defaultdict +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.stat import ( + make_stat_input, +) + class TestDataset: def __init__(self, samples): @@ -38,9 +44,9 @@ def true_types(self): atypes = sample["atype"] unique_atypes = set(atypes) for atype in atypes: - element_counts[atype]["count"] += 1 + element_counts[atype]["count"] += 1 for atype in unique_atypes: - element_counts[atype]["frames"] += 
1 + element_counts[atype]["frames"] += 1 return dict(element_counts) @@ -70,13 +76,11 @@ def test_make_stat_input(self): self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements") expected_true_types = { - 1: {"count": 1, "frames": 1}, - 2: {"count": 1, "frames": 1}, + 1: {"count": 1, "frames": 1}, + 2: {"count": 1, "frames": 1}, } actual_true_types = self.system.true_types() - self.assertEqual( - expected_true_types, actual_true_types, "true_types is wrong" - ) + self.assertEqual(expected_true_types, actual_true_types, "true_types is wrong") if __name__ == "__main__": From ca7fc84624b9da01b2afe5f7aeed61ab9d63a4e3 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 2 Jan 2025 11:02:50 +0800 Subject: [PATCH 13/82] stat --- deepmd/pt/utils/stat.py | 68 +++-------------------------------------- 1 file changed, 4 insertions(+), 64 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index aafe97d4cc..6c9c410bcd 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -78,7 +78,6 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): collect_values = np.unique(torch.cat(sys_stat['atype']).numpy()) collect_elements.update(collect_values) - for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass @@ -88,7 +87,6 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) - element_counts = dataset.true_types() for elem, data in element_counts.items(): count = data["count"] @@ -101,7 +99,7 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors total_element_counts[elem]["indices"].append({ "sys_index": sys_index, "frames": frames - }) + }) for elem, data in total_element_counts.items(): count = data["count"] indices_count = len(data["indices"]) @@ -111,8 +109,8 @@ def 
make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors for miss in missing_elements: sys_indices = total_element_counts[miss].get('indices', []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] sys = datasets[sys_index] frame_data = sys.__getitem__(frames) sys_stat_new = {} @@ -131,7 +129,6 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors sys_stat_new[dd] = frame_data[dd] else: pass -<<<<<<< HEAD for key in sys_stat_new: if isinstance(sys_stat_new[key], np.float32): pass @@ -172,66 +169,9 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors elif isinstance(stat_data[dd], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) - lst.append(sys_stat) -======= - - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): - pass - elif sys_stat[key] is None or sys_stat[key][0] is None: - sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - sys_stat[key] = torch.cat(sys_stat[key], dim=0) - dict_to_device(sys_stat) - lst.append(sys_stat) - - collect_elements = set() - all_element = set() - for i in lst: - collect_values = np.unique(i["atype"].cpu().numpy()) - collect_elements.update(collect_values) - for i in datasets: - all_elements_in_dataset = i.get_all_atype - all_element.update(all_elements_in_dataset) - missing_element = all_element - collect_elements - for miss in missing_element: - for i in datasets: - if i.element_to_frames.get(miss, []) is not None: - frame_indices = i.element_to_frames.get(miss, []) - frame_data = i.__getitem__(frame_indices[0]) - break - else: - pass - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if frame_data[dd] is None: - sys_stat_new[dd] = None - elif isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] - 
frame_data[dd] = torch.from_numpy(frame_data[dd]) - frame_data[dd] = frame_data[dd].unsqueeze(0) - sys_stat_new[dd].append(frame_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: - sys_stat_new[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) - lst.append(sys_stat_new) - ->>>>>>> dc6430730dd18087d0b6fafd98c5e4add795a57a + lst.append(sys_stat) return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 05128d363c7b306f49d6f3a4be29227e13719ef9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 03:04:17 +0000 Subject: [PATCH 14/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 6c9c410bcd..682849bdec 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,6 +36,7 @@ log = logging.getLogger(__name__) + def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat): """Pack data for statistics. 
@@ -75,13 +76,16 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors sys_stat[dd] = stat_data[dd] else: pass - if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): - collect_values = np.unique(torch.cat(sys_stat['atype']).numpy()) + if "atype" in sys_stat and isinstance(sys_stat["atype"], list): + collect_values = np.unique(torch.cat(sys_stat["atype"]).numpy()) collect_elements.update(collect_values) for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -93,24 +97,32 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors frames = data["frames"] total_element_types.add(elem) if elem not in total_element_counts: - total_element_counts[elem] = {"count": 0, "frames": 0, "indices": []} + total_element_counts[elem] = { + "count": 0, + "frames": 0, + "indices": [], + } total_element_counts[elem]["count"] += count - if len(total_element_counts[elem]["indices"]) < min_frames_per_element_forstat: - total_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": frames - }) + if ( + len(total_element_counts[elem]["indices"]) + < min_frames_per_element_forstat + ): + total_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": frames} + ) for elem, data in total_element_counts.items(): count = data["count"] indices_count = len(data["indices"]) if indices_count < min_frames_per_element_forstat: - log.warning(f'The number of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}') + log.warning( + f"The number 
of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}" + ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = total_element_counts[miss].get('indices', []) + sys_indices = total_element_counts[miss].get("indices", []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] sys = datasets[sys_index] frame_data = sys.__getitem__(frames) sys_stat_new = {} @@ -169,9 +181,10 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors elif isinstance(stat_data[dd], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) - lst.append(sys_stat) + lst.append(sys_stat) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 48286199d9f079495cc870a26889fc4fd6e749fb Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 2 Jan 2025 19:37:10 +0800 Subject: [PATCH 15/82] check --- deepmd/pt/train/training.py | 2 + deepmd/pt/utils/dataset.py | 40 +++---- deepmd/pt/utils/stat.py | 210 ++++++++++++++++++------------------ deepmd/utils/argcheck.py | 7 ++ 4 files changed, 133 insertions(+), 126 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 243bf5af0e..e6e3276216 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -145,6 +145,7 @@ def __init__( self.min_frames_per_element_forstat = training_params.get( "min_frames_per_element_forstat", 10 ) + self.enable_element_completion = training_params.get("enable_element_completion", True) self.change_bias_after_training = training_params.get( "change_bias_after_training", False ) @@ -230,6 +231,7 @@ def get_sample(): _training_data.dataloaders, _data_stat_nbatch, self.min_frames_per_element_forstat, + self.enable_element_completion, ) return sampled diff --git 
def get_frame_index(self):
    """Map each element type to the frames of this mixed-type system that contain it.

    Only meaningful for mixed-type data, where the per-frame atom types are
    stored in ``real_atom_types.npy`` inside each ``set.*`` directory.

    Returns
    -------
    element_counts : dict
        Keyed by element type. Each value is a dict with:
        - "frames": int
            Total number of frames in which the element appears.
        - "indices": list of int
            System-wide frame indices (offset across all ``set.*``
            directories) at which the element is found.
    """
    element_counts = defaultdict(lambda: {"frames": 0, "indices": []})
    # Frames of earlier set.* directories precede this set's frames in the
    # system-wide frame numbering used by __getitem__ — TODO confirm the
    # set ordering matches self._data_system.dirs.
    frame_offset = 0
    for set_file in self._data_system.dirs:
        element_data = self._data_system._load_type_mix(set_file)
        for elem in np.unique(element_data):
            # Rows (frames) of this set that contain `elem`.
            rows = np.where(np.any(element_data == elem, axis=1))[0]
            element_counts[elem]["frames"] += len(rows)
            # Offset local row indices so they address the whole system;
            # returning per-set indices would mis-address frames for
            # systems with more than one set.* directory.
            element_counts[elem]["indices"].extend((rows + frame_offset).tolist())
        frame_offset += element_data.shape[0]
    return dict(element_counts)
""" lst = [] log.info(f"Packing data for statistics from {len(datasets)} systems") collect_elements = set() total_element_types = set() - total_element_counts = {} + global_element_counts = {} if datasets[0].mixed_type: - for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)): - sys_stat = {} - with torch.device("cpu"): + if enable_element_completion: + log.info( + f'Element check enabled. ' + f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' + ) + else: + log.info("Element completion is disabled. Skipping missing element handling.") + + def process_batches(dataloader, sys_stat): + """Process batches from a dataloader to collect statistics.""" + iterator = iter(dataloader) + numb_batches = min(nbatches, len(dataloader)) + for _ in range(numb_batches): + try: + stat_data = next(iterator) + except StopIteration: iterator = iter(dataloader) - numb_batches = min(nbatches, len(dataloader)) - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloader) - stat_data = next(iterator) - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] - else: - pass - if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): - collect_values = np.unique(torch.cat(sys_stat['atype']).numpy()) - collect_elements.update(collect_values) - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): + stat_data = next(iterator) + for dd in stat_data: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(stat_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat[dd] = stat_data[dd] + else: pass - elif sys_stat[key] is None or 
(isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): - sys_stat[key] = None - elif isinstance(sys_stat[key][0], torch.Tensor): - sys_stat[key] = torch.cat(sys_stat[key], dim=0) - dict_to_device(sys_stat) - lst.append(sys_stat) - element_counts = dataset.true_types() + + def finalize_stats(sys_stat): + """Finalize statistics by concatenating tensors.""" + for key in sys_stat: + if isinstance(sys_stat[key], np.float32): + pass + elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + sys_stat[key] = None + elif isinstance(sys_stat[key][0], torch.Tensor): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + dict_to_device(sys_stat) + + for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)): + sys_stat = {} + with torch.device("cpu"): + process_batches(dataloader, sys_stat) + if datasets[0].mixed_type: + if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): + collect_values = torch.unique(torch.cat(sys_stat['atype']).flatten(), sorted=True) + collect_elements.update(collect_values.tolist()) + + finalize_stats(sys_stat) + lst.append(sys_stat) + + if datasets[0].mixed_type: + element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): - count = data["count"] - frames = data["frames"] + indices = data["indices"] total_element_types.add(elem) - if elem not in total_element_counts: - total_element_counts[elem] = {"count": 0, "frames": 0, "indices": []} - total_element_counts[elem]["count"] += count - if len(total_element_counts[elem]["indices"]) < min_frames_per_element_forstat: - total_element_counts[elem]["indices"].append({ + if elem not in global_element_counts: + global_element_counts[elem] = {"frames": [], "indices": []} + global_element_counts[elem]["frames"].extend(indices) + if len(global_element_counts[elem]["indices"]) < min_frames_per_element_forstat: + 
global_element_counts[elem]["indices"].append({ "sys_index": sys_index, - "frames": frames - }) - for elem, data in total_element_counts.items(): - count = data["count"] + "frames": indices + }) + if datasets[0].mixed_type and enable_element_completion: + for elem, data in global_element_counts.items(): indices_count = len(data["indices"]) if indices_count < min_frames_per_element_forstat: - log.warning(f'The number of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}') + log.warning( + f'The number of frames with element {elem} is {indices_count}, ' + f'which is less than the required {min_frames_per_element_forstat}' + ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = total_element_counts[miss].get('indices', []) + sys_indices = global_element_counts[miss].get('indices', []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] sys = datasets[sys_index] - frame_data = sys.__getitem__(frames) - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if frame_data[dd] is None: - sys_stat_new[dd] = None - elif isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] - frame_data[dd] = torch.from_numpy(frame_data[dd]) - frame_data[dd] = frame_data[dd].unsqueeze(0) - sys_stat_new[dd].append(frame_data[dd]) - elif isinstance(frame_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: - sys_stat_new[key] = None - elif isinstance(frame_data[dd], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) - lst.append(sys_stat_new) - else: - for i in range(len(datasets)): - 
sys_stat = {} - with torch.device("cpu"): - iterator = iter(dataloaders[i]) - numb_batches = min(nbatches, len(dataloaders[i])) - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloaders[i]) - stat_data = next(iterator) - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] + for frame in frames: + frame_data = sys.__getitem__(frame) + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + tensor_data = torch.from_numpy(frame_data[dd]) + tensor_data = tensor_data.unsqueeze(0) + sys_stat_new[dd].append(tensor_data) + elif isinstance(frame_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] else: pass - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): - pass - elif sys_stat[key] is None or sys_stat[key][0] is None: - sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - sys_stat[key] = torch.cat(sys_stat[key], dim=0) - dict_to_device(sys_stat) - lst.append(sys_stat) + for key in sys_stat_new: + if isinstance(sys_stat_new[key], np.float32): + pass + elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + sys_stat_new[key] = None + elif isinstance(sys_stat_new[key][0], torch.Tensor): + sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) + dict_to_device(sys_stat_new) + lst.append(sys_stat_new) return lst def _restore_from_file( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5b1fb05f2e..1ec86d0066 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2832,6 +2832,13 @@ def training_args( optional=True, doc="The minimum 
number of frames per element used for statistics.", ), + Argument( + "enable_element_completion", + bool, + optional=False, + default=True, + doc='Whether to check elements when using the mixed type', + ), ] variants = [ Variant( From c9406e47819e946ecdd414544d514aaa306646ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:40:31 +0000 Subject: [PATCH 16/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/train/training.py | 4 ++- deepmd/pt/utils/dataset.py | 6 ++-- deepmd/pt/utils/stat.py | 59 +++++++++++++++++++++++++------------ deepmd/utils/argcheck.py | 2 +- 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e6e3276216..8f225bea45 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -145,7 +145,9 @@ def __init__( self.min_frames_per_element_forstat = training_params.get( "min_frames_per_element_forstat", 10 ) - self.enable_element_completion = training_params.get("enable_element_completion", True) + self.enable_element_completion = training_params.get( + "enable_element_completion", True + ) self.change_bias_after_training = training_params.get( "change_bias_after_training", False ) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 2f51861799..d054fc1c17 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -45,7 +45,7 @@ def __getitem__(self, index): def get_frame_index(self): """ - Get the frame index and the number of frames with all the elements in the system. + Get the frame index and the number of frames with all the elements in the system. This function is only used in the mixed type. Returns @@ -59,8 +59,8 @@ def get_frame_index(self): - "indices": list of int A list of row indices where the element is found in the dataset. 
""" - element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) - set_files = self._data_system.dirs + element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) + set_files = self._data_system.dirs for set_file in set_files: element_data = self._data_system._load_type_mix(set_file) unique_elements = np.unique(element_data) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 81f490b1fb..4a55277574 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -37,7 +37,13 @@ log = logging.getLogger(__name__) -def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): +def make_stat_input( + datasets, + dataloaders, + nbatches, + min_frames_per_element_forstat, + enable_element_completion=True, +): """Pack data for statistics. Args: @@ -59,11 +65,13 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors if datasets[0].mixed_type: if enable_element_completion: log.info( - f'Element check enabled. ' - f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' + f"Element check enabled. " + f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." ) else: - log.info("Element completion is disabled. Skipping missing element handling.") + log.info( + "Element completion is disabled. Skipping missing element handling." 
+ ) def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -92,7 +100,10 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -103,9 +114,11 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type: - if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): - collect_values = torch.unique(torch.cat(sys_stat['atype']).flatten(), sorted=True) - collect_elements.update(collect_values.tolist()) + if "atype" in sys_stat and isinstance(sys_stat["atype"], list): + collect_values = torch.unique( + torch.cat(sys_stat["atype"]).flatten(), sorted=True + ) + collect_elements.update(collect_values.tolist()) finalize_stats(sys_stat) lst.append(sys_stat) @@ -116,24 +129,32 @@ def finalize_stats(sys_stat): indices = data["indices"] total_element_types.add(elem) if elem not in total_element_counts: - total_element_counts[elem] = {"count": 0, "frames": 0, "indices": []} + total_element_counts[elem] = { + "count": 0, + "frames": 0, + "indices": [], + } total_element_counts[elem]["count"] += count - if len(total_element_counts[elem]["indices"]) < min_frames_per_element_forstat: - total_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": frames - }) + if ( + len(total_element_counts[elem]["indices"]) + < min_frames_per_element_forstat + ): + total_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": frames} + ) for elem, data in total_element_counts.items(): count = data["count"] 
indices_count = len(data["indices"]) if indices_count < min_frames_per_element_forstat: - log.warning(f'The number of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}') + log.warning( + f"The number of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}" + ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = total_element_counts[miss].get('indices', []) + sys_indices = total_element_counts[miss].get("indices", []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] sys = datasets[sys_index] frame_data = sys.__getitem__(frames) sys_stat_new = {} @@ -192,7 +213,7 @@ def finalize_stats(sys_stat): elif isinstance(stat_data[dd], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) - lst.append(sys_stat) + lst.append(sys_stat) return lst diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1ec86d0066..05911fc317 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2837,7 +2837,7 @@ def training_args( bool, optional=False, default=True, - doc='Whether to check elements when using the mixed type', + doc="Whether to check elements when using the mixed type", ), ] variants = [ From 2224f61160f59138f76047f1d332d358c6c7b126 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 2 Jan 2025 19:40:34 +0800 Subject: [PATCH 17/82] =?UTF-8?q?chec=E5=9D=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepmd/pt/utils/stat.py | 110 +++++++++++++++------------------------- 1 file changed, 41 insertions(+), 69 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 81f490b1fb..318b622736 100644 --- a/deepmd/pt/utils/stat.py +++ 
b/deepmd/pt/utils/stat.py @@ -115,84 +115,56 @@ def finalize_stats(sys_stat): for elem, data in element_counts.items(): indices = data["indices"] total_element_types.add(elem) - if elem not in total_element_counts: - total_element_counts[elem] = {"count": 0, "frames": 0, "indices": []} - total_element_counts[elem]["count"] += count - if len(total_element_counts[elem]["indices"]) < min_frames_per_element_forstat: - total_element_counts[elem]["indices"].append({ + if elem not in global_element_counts: + global_element_counts[elem] = {"frames": [], "indices": []} + global_element_counts[elem]["frames"].extend(indices) + if len(global_element_counts[elem]["indices"]) < min_frames_per_element_forstat: + global_element_counts[elem]["indices"].append({ "sys_index": sys_index, - "frames": frames - }) - for elem, data in total_element_counts.items(): - count = data["count"] + "frames": indices + }) + if datasets[0].mixed_type and enable_element_completion: + for elem, data in global_element_counts.items(): indices_count = len(data["indices"]) if indices_count < min_frames_per_element_forstat: - log.warning(f'The number of frame with element {elem} is {indices_count}, which is less than the expected maximum value {min_frames_per_element_forstat}') + log.warning( + f'The number of frames with element {elem} is {indices_count}, ' + f'which is less than the required {min_frames_per_element_forstat}' + ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = total_element_counts[miss].get('indices', []) + sys_indices = global_element_counts[miss].get('indices', []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] sys = datasets[sys_index] - frame_data = sys.__getitem__(frames) - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if frame_data[dd] is None: - sys_stat_new[dd] = None - elif 
isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] - frame_data[dd] = torch.from_numpy(frame_data[dd]) - frame_data[dd] = frame_data[dd].unsqueeze(0) - sys_stat_new[dd].append(frame_data[dd]) - elif isinstance(frame_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: - sys_stat_new[key] = None - elif isinstance(frame_data[dd], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) - lst.append(sys_stat_new) - else: - for i in range(len(datasets)): - sys_stat = {} - with torch.device("cpu"): - iterator = iter(dataloaders[i]) - numb_batches = min(nbatches, len(dataloaders[i])) - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloaders[i]) - stat_data = next(iterator) - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] + for frame in frames: + frame_data = sys.__getitem__(frame) + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + tensor_data = torch.from_numpy(frame_data[dd]) + tensor_data = tensor_data.unsqueeze(0) + sys_stat_new[dd].append(tensor_data) + elif isinstance(frame_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] else: pass - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): - pass - elif sys_stat[key] is None or sys_stat[key][0] is None: - sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): - 
sys_stat[key] = torch.cat(sys_stat[key], dim=0) - dict_to_device(sys_stat) - lst.append(sys_stat) + for key in sys_stat_new: + if isinstance(sys_stat_new[key], np.float32): + pass + elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + sys_stat_new[key] = None + elif isinstance(sys_stat_new[key][0], torch.Tensor): + sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) + dict_to_device(sys_stat_new) + lst.append(sys_stat_new) return lst From 9fcee849441fdbc60560ec0bb9e855d4eafc42fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:45:18 +0000 Subject: [PATCH 18/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 0deff520e6..0f262f8de7 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -131,25 +131,27 @@ def finalize_stats(sys_stat): if elem not in global_element_counts: global_element_counts[elem] = {"frames": [], "indices": []} global_element_counts[elem]["frames"].extend(indices) - if len(global_element_counts[elem]["indices"]) < min_frames_per_element_forstat: - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + if ( + len(global_element_counts[elem]["indices"]) + < min_frames_per_element_forstat + ): + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = len(data["indices"]) if indices_count < min_frames_per_element_forstat: log.warning( - f'The number of frames with element {elem} is {indices_count}, ' - f'which is less than the required {min_frames_per_element_forstat}' + f"The number of frames with element 
{elem} is {indices_count}, " + f"which is less than the required {min_frames_per_element_forstat}" ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = global_element_counts[miss].get('indices', []) + sys_indices = global_element_counts[miss].get("indices", []) for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] sys = datasets[sys_index] for frame in frames: frame_data = sys.__getitem__(frame) From fe8579e5813fd8c09d0340393665f408a770c4b2 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 2 Jan 2025 19:46:36 +0800 Subject: [PATCH 19/82] check3 --- deepmd/pt/utils/stat.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 0deff520e6..5119e4179f 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,14 +36,7 @@ log = logging.getLogger(__name__) - -def make_stat_input( - datasets, - dataloaders, - nbatches, - min_frames_per_element_forstat, - enable_element_completion=True, -): +def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): """Pack data for statistics. Args: @@ -65,13 +58,11 @@ def make_stat_input( if datasets[0].mixed_type: if enable_element_completion: log.info( - f"Element check enabled. " - f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." + f'Element check enabled. ' + f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' ) else: - log.info( - "Element completion is disabled. Skipping missing element handling." - ) + log.info("Element completion is disabled. 
Skipping missing element handling.") def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -100,10 +91,7 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or ( - isinstance(sys_stat[key], list) - and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) - ): + elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -114,11 +102,9 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type: - if "atype" in sys_stat and isinstance(sys_stat["atype"], list): - collect_values = torch.unique( - torch.cat(sys_stat["atype"]).flatten(), sorted=True - ) - collect_elements.update(collect_values.tolist()) + if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): + collect_values = torch.unique(torch.cat(sys_stat['atype']).flatten(), sorted=True) + collect_elements.update(collect_values.tolist()) finalize_stats(sys_stat) lst.append(sys_stat) @@ -180,7 +166,6 @@ def finalize_stats(sys_stat): lst.append(sys_stat_new) return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 11138ff878c1e2b61e2c03364d7663138f2ab627 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:49:42 +0000 Subject: [PATCH 20/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index b8c4ade8f3..0f262f8de7 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,7 
+36,14 @@ log = logging.getLogger(__name__) -def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): + +def make_stat_input( + datasets, + dataloaders, + nbatches, + min_frames_per_element_forstat, + enable_element_completion=True, +): """Pack data for statistics. Args: @@ -58,11 +65,13 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors if datasets[0].mixed_type: if enable_element_completion: log.info( - f'Element check enabled. ' - f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' + f"Element check enabled. " + f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." ) else: - log.info("Element completion is disabled. Skipping missing element handling.") + log.info( + "Element completion is disabled. Skipping missing element handling." + ) def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -91,7 +100,10 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -102,9 +114,11 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type: - if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): - collect_values = torch.unique(torch.cat(sys_stat['atype']).flatten(), sorted=True) - collect_elements.update(collect_values.tolist()) + if "atype" in sys_stat and isinstance(sys_stat["atype"], list): + collect_values = torch.unique( + torch.cat(sys_stat["atype"]).flatten(), 
sorted=True + ) + collect_elements.update(collect_values.tolist()) finalize_stats(sys_stat) lst.append(sys_stat) @@ -168,6 +182,7 @@ def finalize_stats(sys_stat): lst.append(sys_stat_new) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From a4a97a3ed95f93257bf073ac1fb7eab8dbf149ee Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 3 Jan 2025 11:20:16 +0800 Subject: [PATCH 21/82] test --- .../sys.000000/set.000/box.npy | Bin 0 -> 308 bytes .../sys.000000/set.000/coord.npy | Bin 0 -> 548 bytes .../sys.000000/set.000/energy.npy | Bin 0 -> 148 bytes .../sys.000000/set.000/force.npy | Bin 0 -> 548 bytes .../sys.000000/set.000/real_atom_numbs.npy | Bin 0 -> 2408 bytes .../sys.000000/set.000/real_atom_types.npy | Bin 0 -> 408 bytes .../pt/mixed_type_data/sys.000000/type.raw | 7 + .../mixed_type_data/sys.000000/type_map.raw | 57 ++++++++ source/tests/pt/test_make_stat_input.py | 129 ++++++++---------- 9 files changed, 123 insertions(+), 70 deletions(-) create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.000/box.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.000/coord.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.000/energy.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.000/force.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_numbs.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_types.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/type.raw create mode 100644 source/tests/pt/mixed_type_data/sys.000000/type_map.raw diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/box.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/box.npy new file mode 100644 index 0000000000000000000000000000000000000000..9a75fc05919f50d033ba8bb3b7a8b3c36c2dd8d0 GIT binary patch literal 308 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= 
zXCxM+0{I%IItrGWItsN4WCN~x5jIB#2xtIN2WO}|R_@ugA0ouSVCnS4(WIFLuEwOf x#bH~^QwNwlh<;JqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItoUHnmP)#3giN=I*~T}l}jf%ENQlN+$-nlARW5L;lE?Cj z7r%1opSH*0$o^u-BS1AjE*^C#tmtyw17u6Cd*+~9+Y4sjYFg>wGq>N-q?yH0g?Xlf z*gY@D7x$0t`?yzo|H&)!9RFeDwI%C7dIm}x$lXaDMB9Qz}V&vImyp5%}?^@RQH^m50Aw>cc& z`Au?o+=tuDTiu$D2~JlX8m6^4BKuuU*4V+hf04sBiF6ddGx5K4Fp#?KP=0Nuqmf#= zL!{O=2Zy6gj`0O94)c<)I^4e2>=-fK#o-*o0*9*D4#!xg&-=wMH9AC2EO8WFywskf zTg>5Cn6=}JB@gX)KL6!#MmNlnp?8@>@zHh%p{6Xyle`ZdKxzz68avwYeQ*G&VP4?p R*t)j$z_Izg4n?aH9094|%s>DD literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/energy.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/energy.npy new file mode 100644 index 0000000000000000000000000000000000000000..831de792c4b928dcb62d01fc38afb0717b4af2e7 GIT binary patch literal 148 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= rXCxM+0{I%II+{8PwF(pfuJ8qx2dy4_Il%Gf&w&rI?gt;Xm>dKE^MWUe literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/force.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/force.npy new file mode 100644 index 0000000000000000000000000000000000000000..3e112883296955ddfd5affcddc7af99fcc3532d6 GIT binary patch literal 548 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItoUHnmP)#3giN=&EMYHur3y|X_J_?OZ8sHw&14lJ@xL|wz=Ql?Kut< z@0Omn%T-o$Yip0c&AQ?o+q2BG%qO)3+05H_V|V-XkK3N7OxZo_x$oZHuRm<-ym)Dk ze3a_m(nbZF&3f&7{3KNN$bR;<6;XR(!`aigSG8Vk4|nu>+q+7pcK$C)ZBqV;+a|5B zwtjEDckeq9=6x;u4D4jwpY0T2Ibp?r?XQjJ+9x*NMJ#*f9-O!D5#uCV^~>w5r!6qE zednoe8(wwRwtiLJzH2Y{?5TOa&ek(xid|z4yKU0z!#17Y?Dws;T4TF+k>EbAz7*>f zdpdS8SWDROT$r&}Up8ppVeS=s<>U672=>fhC3^;TZOHfHq>yM!^nP8?pcZ literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_types.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_types.npy new file mode 100644 index 
0000000000000000000000000000000000000000..1522e39c490791693b10d7c582af28bec1146a69 GIT binary patch literal 408 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC!@qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItu2RItsN4WCJc?1_;1Kvq9wr(P$MYpI91Z4zcQC>R|eCxeq1|Q-|(O M9;iEsrRAXN0YbbUga7~l literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/type.raw b/source/tests/pt/mixed_type_data/sys.000000/type.raw new file mode 100644 index 0000000000..a21090781f --- /dev/null +++ b/source/tests/pt/mixed_type_data/sys.000000/type.raw @@ -0,0 +1,7 @@ +0 +0 +0 +0 +0 +0 +0 diff --git a/source/tests/pt/mixed_type_data/sys.000000/type_map.raw b/source/tests/pt/mixed_type_data/sys.000000/type_map.raw new file mode 100644 index 0000000000..0233506281 --- /dev/null +++ b/source/tests/pt/mixed_type_data/sys.000000/type_map.raw @@ -0,0 +1,57 @@ +Ag +Al +As +Au +B +Bi +C +Ca +Cd +Cl +Co +Cr +Cs +Cu +Fe +Ga +Ge +H +Hf +Hg +In +Ir +K +Mg +Mn +Mo +N +Na +Nb +Ni +O +Os +P +Pb +Pd +Pt +Rb +Re +Rh +Ru +S +Sb +Sc +Se +Si +Sn +Sr +Ta +Tc +Te +Ti +Tl +V +W +Y +Zn +Zr diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 4bdfc26732..802d74d5d1 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,87 +1,76 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest -from collections import ( - defaultdict, -) - +import numpy as np import torch -from torch.utils.data import ( - DataLoader, -) - -from deepmd.pt.utils.stat import ( - make_stat_input, -) - - -class TestDataset: - def __init__(self, samples): - self.samples = samples - self.element_to_frames = defaultdict(list) - self.mixed_type = True - for idx, sample in enumerate(samples): - atypes = sample["atype"] - for atype in atypes: - self.element_to_frames[atype].append(idx) - - @property - def get_all_atype(self): - return set(self.element_to_frames.keys()) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - 
sample = self.samples[idx] - return { - "atype": torch.tensor(sample["atype"], dtype=torch.long), - "energy": torch.tensor(sample["energy"], dtype=torch.float32), - } - - def true_types(self): - element_counts = defaultdict(lambda: {"count": 0, "frames": 0}) - for idx, sample in enumerate(self.samples): - atypes = sample["atype"] - unique_atypes = set(atypes) - for atype in atypes: - element_counts[atype]["count"] += 1 - for atype in unique_atypes: - element_counts[atype]["frames"] += 1 - return dict(element_counts) +from torch.utils.data import DataLoader +from deepmd.pt.utils.stat import make_stat_input,compute_output_stats +from deepmd.pt.utils.dataset import DeepmdDataSetForLoader +def collate_fn(batch): + if isinstance(batch, dict): + batch = [batch] + collated_batch = {} + for key in batch[0].keys(): + data_list = [d[key] for d in batch] + if isinstance(data_list[0], np.ndarray): + data_np = np.stack(data_list) + collated_batch[key] = torch.from_numpy(data_np) + else: + collated_batch[key] = torch.tensor(data_list) + return collated_batch class TestMakeStatInput(unittest.TestCase): - def setUp(self): - self.system = TestDataset( - [ - {"atype": [1], "energy": -1.0}, - {"atype": [2], "energy": -2.0}, - ] - ) - self.datasets = [self.system] - self.dataloaders = [ - DataLoader(self.system, batch_size=1, shuffle=False), - ] + @classmethod + def setUpClass(cls): + system_path = "mixed_type_data/sys.000000" + cls.alltype = {19, 6, 17, 12, 30, 36} + cls.datasets = [DeepmdDataSetForLoader(system=system_path)] + weights = torch.tensor([0.1] * len(cls.datasets)) + sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True) + cls.dataloaders = [] + for dataset in cls.datasets: + dataloader = DataLoader( + dataset, + sampler=sampler, + batch_size=1, + num_workers=0, + drop_last=False, + collate_fn=collate_fn, + pin_memory=True, + ) + cls.dataloaders.append(dataloader) def test_make_stat_input(self): nbatches = 1 lst = 
make_stat_input( - self.datasets, - self.dataloaders, + datasets=self.datasets, + dataloaders=self.dataloaders, nbatches=nbatches, min_frames_per_element_forstat=1, + enable_element_completion=True, ) - all_elements = self.system.get_all_atype - unique_elements = {1, 2} - self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements") - - expected_true_types = { - 1: {"count": 1, "frames": 1}, - 2: {"count": 1, "frames": 1}, - } - actual_true_types = self.system.true_types() - self.assertEqual(expected_true_types, actual_true_types, "true_types is wrong") + coll_ele = set() + for i in lst: + ele = np.unique(i['atype'].cpu().numpy()) + coll_ele.update(ele) + if not coll_ele == self.alltype: + self.assertFalse('Wrong') + def test_make_stat_input_nocomplete(self): + nbatches = 1 + lst = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=nbatches, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + coll_ele = set() + for i in lst: + ele = np.unique(i['atype'].cpu().numpy()) + coll_ele.update(ele) + if coll_ele == self.alltype: + self.assertFalse('Wrong') if __name__ == "__main__": unittest.main() From 10e538d12b707d3d321ba2b6e3c642cc5e6406de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 03:21:43 +0000 Subject: [PATCH 22/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 43 ++++++++++++++++--------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 802d74d5d1..9bed26c87c 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,16 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest + import numpy as np import torch -from torch.utils.data import 
DataLoader -from deepmd.pt.utils.stat import make_stat_input,compute_output_stats -from deepmd.pt.utils.dataset import DeepmdDataSetForLoader +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) + def collate_fn(batch): if isinstance(batch, dict): batch = [batch] collated_batch = {} - for key in batch[0].keys(): + for key in batch[0].keys(): data_list = [d[key] for d in batch] if isinstance(data_list[0], np.ndarray): data_np = np.stack(data_list) @@ -19,20 +28,23 @@ def collate_fn(batch): collated_batch[key] = torch.tensor(data_list) return collated_batch + class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): system_path = "mixed_type_data/sys.000000" cls.alltype = {19, 6, 17, 12, 30, 36} cls.datasets = [DeepmdDataSetForLoader(system=system_path)] - weights = torch.tensor([0.1] * len(cls.datasets)) - sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True) + weights = torch.tensor([0.1] * len(cls.datasets)) + sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples=len(weights), replacement=True + ) cls.dataloaders = [] for dataset in cls.datasets: dataloader = DataLoader( dataset, sampler=sampler, - batch_size=1, + batch_size=1, num_workers=0, drop_last=False, collate_fn=collate_fn, @@ -43,34 +55,35 @@ def setUpClass(cls): def test_make_stat_input(self): nbatches = 1 lst = make_stat_input( - datasets=self.datasets, - dataloaders=self.dataloaders, + datasets=self.datasets, + dataloaders=self.dataloaders, nbatches=nbatches, min_frames_per_element_forstat=1, enable_element_completion=True, ) coll_ele = set() for i in lst: - ele = np.unique(i['atype'].cpu().numpy()) + ele = np.unique(i["atype"].cpu().numpy()) coll_ele.update(ele) if not coll_ele == self.alltype: - self.assertFalse('Wrong') + self.assertFalse("Wrong") def test_make_stat_input_nocomplete(self): 
nbatches = 1 lst = make_stat_input( - datasets=self.datasets, - dataloaders=self.dataloaders, + datasets=self.datasets, + dataloaders=self.dataloaders, nbatches=nbatches, min_frames_per_element_forstat=1, enable_element_completion=False, ) coll_ele = set() for i in lst: - ele = np.unique(i['atype'].cpu().numpy()) + ele = np.unique(i["atype"].cpu().numpy()) coll_ele.update(ele) if coll_ele == self.alltype: - self.assertFalse('Wrong') + self.assertFalse("Wrong") + if __name__ == "__main__": unittest.main() From 203dc4ef9a907b57135365e4cae2e061e8ebb6f3 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 3 Jan 2025 14:56:01 +0800 Subject: [PATCH 23/82] ttt --- source/tests/pt/test_make_stat_input.py | 39 +++++++++++++------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 802d74d5d1..a0417407f9 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -5,6 +5,8 @@ from torch.utils.data import DataLoader from deepmd.pt.utils.stat import make_stat_input,compute_output_stats from deepmd.pt.utils.dataset import DeepmdDataSetForLoader +from deepmd.utils.data import DataRequirementItem + def collate_fn(batch): if isinstance(batch, dict): @@ -23,8 +25,17 @@ class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): system_path = "mixed_type_data/sys.000000" - cls.alltype = {19, 6, 17, 12, 30, 36} - cls.datasets = [DeepmdDataSetForLoader(system=system_path)] + cls.datasets = DeepmdDataSetForLoader(system=system_path) + data_requirements = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + ), + + ] + cls.datasets.add_data_requirement(data_requirements) + cls.datasets=[cls.datasets] weights = torch.tensor([0.1] * len(cls.datasets)) sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True) cls.dataloaders = [] @@ -41,36 +52,26 @@ def setUpClass(cls): 
cls.dataloaders.append(dataloader) def test_make_stat_input(self): - nbatches = 1 lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, - nbatches=nbatches, + nbatches=1, min_frames_per_element_forstat=1, enable_element_completion=True, ) - coll_ele = set() - for i in lst: - ele = np.unique(i['atype'].cpu().numpy()) - coll_ele.update(ele) - if not coll_ele == self.alltype: - self.assertFalse('Wrong') + bias,_=compute_output_stats(lst,ntypes=57) + print(bias) def test_make_stat_input_nocomplete(self): - nbatches = 1 lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, - nbatches=nbatches, + nbatches=1, min_frames_per_element_forstat=1, enable_element_completion=False, ) - coll_ele = set() - for i in lst: - ele = np.unique(i['atype'].cpu().numpy()) - coll_ele.update(ele) - if coll_ele == self.alltype: - self.assertFalse('Wrong') + bias,_=compute_output_stats(lst,ntypes=57) + print(bias) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From 6a655618790f20b3c8eb51df4f88fbc2b029584d Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 3 Jan 2025 14:58:24 +0800 Subject: [PATCH 24/82] t --- source/tests/pt/test_make_stat_input.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 7e57a27419..63b8939101 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest - import numpy as np import torch from torch.utils.data import DataLoader from deepmd.pt.utils.stat import make_stat_input,compute_output_stats from deepmd.pt.utils.dataset import DeepmdDataSetForLoader +from deepmd.utils.data import DataRequirementItem def collate_fn(batch): if isinstance(batch, dict): batch = [batch] collated_batch = {} - for key in batch[0].keys(): + for key in batch[0].keys(): 
data_list = [d[key] for d in batch] if isinstance(data_list[0], np.ndarray): data_np = np.stack(data_list) @@ -20,7 +20,6 @@ def collate_fn(batch): collated_batch[key] = torch.tensor(data_list) return collated_batch - class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): @@ -43,7 +42,7 @@ def setUpClass(cls): dataloader = DataLoader( dataset, sampler=sampler, - batch_size=1, + batch_size=1, num_workers=0, drop_last=False, collate_fn=collate_fn, @@ -55,7 +54,7 @@ def test_make_stat_input(self): lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, - nbatches=nbatches, + nbatches=1, min_frames_per_element_forstat=1, enable_element_completion=True, ) From 603aee981bb499d7114d5005291c2bb17691bcbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 06:59:49 +0000 Subject: [PATCH 25/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 57 ++++++++++++++++--------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 63b8939101..f695533ec4 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,17 +1,29 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest + import numpy as np import torch -from torch.utils.data import DataLoader -from deepmd.pt.utils.stat import make_stat_input,compute_output_stats -from deepmd.pt.utils.dataset import DeepmdDataSetForLoader -from deepmd.utils.data import DataRequirementItem +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, + make_stat_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + def collate_fn(batch): if 
isinstance(batch, dict): batch = [batch] collated_batch = {} - for key in batch[0].keys(): + for key in batch[0].keys(): data_list = [d[key] for d in batch] if isinstance(data_list[0], np.ndarray): data_np = np.stack(data_list) @@ -20,6 +32,7 @@ def collate_fn(batch): collated_batch[key] = torch.tensor(data_list) return collated_batch + class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): @@ -27,22 +40,23 @@ def setUpClass(cls): cls.datasets = DeepmdDataSetForLoader(system=system_path) data_requirements = [ DataRequirementItem( - "energy", - ndof=1, - atomic=False, - ), - + "energy", + ndof=1, + atomic=False, + ), ] cls.datasets.add_data_requirement(data_requirements) - cls.datasets=[cls.datasets] - weights = torch.tensor([0.1] * len(cls.datasets)) - sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True) + cls.datasets = [cls.datasets] + weights = torch.tensor([0.1] * len(cls.datasets)) + sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples=len(weights), replacement=True + ) cls.dataloaders = [] for dataset in cls.datasets: dataloader = DataLoader( dataset, sampler=sampler, - batch_size=1, + batch_size=1, num_workers=0, drop_last=False, collate_fn=collate_fn, @@ -52,25 +66,26 @@ def setUpClass(cls): def test_make_stat_input(self): lst = make_stat_input( - datasets=self.datasets, - dataloaders=self.dataloaders, + datasets=self.datasets, + dataloaders=self.dataloaders, nbatches=1, min_frames_per_element_forstat=1, enable_element_completion=True, ) - bias,_=compute_output_stats(lst,ntypes=57) + bias, _ = compute_output_stats(lst, ntypes=57) print(bias) def test_make_stat_input_nocomplete(self): lst = make_stat_input( - datasets=self.datasets, - dataloaders=self.dataloaders, + datasets=self.datasets, + dataloaders=self.dataloaders, nbatches=1, min_frames_per_element_forstat=1, enable_element_completion=False, ) - bias,_=compute_output_stats(lst,ntypes=57) + bias, _ = 
compute_output_stats(lst, ntypes=57) print(bias) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 1c103c42d1af39498d78a20c12578246103dda25 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 3 Jan 2025 15:23:58 +0800 Subject: [PATCH 26/82] d --- deepmd/pt/utils/stat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 0f262f8de7..82a64fe494 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -123,6 +123,7 @@ def finalize_stats(sys_stat): finalize_stats(sys_stat) lst.append(sys_stat) + #get frame index if datasets[0].mixed_type: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): @@ -138,6 +139,7 @@ def finalize_stats(sys_stat): global_element_counts[elem]["indices"].append( {"sys_index": sys_index, "frames": indices} ) + #Check whether the element used for statistics is complete in mixed_type if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = len(data["indices"]) From e3a1c9b14849ac67713e8e0c4caffa9bbc8bd281 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 07:25:26 +0000 Subject: [PATCH 27/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 82a64fe494..ec461f279a 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -123,7 +123,7 @@ def finalize_stats(sys_stat): finalize_stats(sys_stat) lst.append(sys_stat) - #get frame index + # get frame index if datasets[0].mixed_type: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): @@ -139,7 +139,7 @@ def finalize_stats(sys_stat): 
global_element_counts[elem]["indices"].append( {"sys_index": sys_index, "frames": indices} ) - #Check whether the element used for statistics is complete in mixed_type + # Check whether the element used for statistics is complete in mixed_type if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = len(data["indices"]) From 533e95e15f2168732f3b669952b0f217d8a93926 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 3 Jan 2025 17:09:42 +0800 Subject: [PATCH 28/82] ll --- source/tests/pt/test_make_stat_input.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index f695533ec4..3161feab1a 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -32,7 +32,6 @@ def collate_fn(batch): collated_batch[key] = torch.tensor(data_list) return collated_batch - class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): @@ -63,6 +62,9 @@ def setUpClass(cls): pin_memory=True, ) cls.dataloaders.append(dataloader) + + def count_non_zero_elements(self, tensor, threshold=1e-8): + return torch.sum(torch.abs(tensor) > threshold).item() def test_make_stat_input(self): lst = make_stat_input( @@ -73,8 +75,14 @@ def test_make_stat_input(self): enable_element_completion=True, ) bias, _ = compute_output_stats(lst, ntypes=57) - print(bias) - + energy = bias.get('energy') + self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") + non_zero_count = self.count_non_zero_elements(energy) + self.assertEqual( + non_zero_count, + 6, + f"Expected exactly 7 non-zero elements, but got {non_zero_count}." 
+ ) def test_make_stat_input_nocomplete(self): lst = make_stat_input( datasets=self.datasets, @@ -84,8 +92,14 @@ def test_make_stat_input_nocomplete(self): enable_element_completion=False, ) bias, _ = compute_output_stats(lst, ntypes=57) - print(bias) - + energy = bias.get('energy') + self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") + non_zero_count = self.count_non_zero_elements(energy) + self.assertLess( + non_zero_count, + 6, + f"Expected fewer than 7 non-zero elements, but got {non_zero_count}." + ) if __name__ == "__main__": unittest.main() From 714c197fc67bee7ca12d6c94b68e1e56958511e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 09:11:09 +0000 Subject: [PATCH 29/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 3161feab1a..cf673637e3 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -32,6 +32,7 @@ def collate_fn(batch): collated_batch[key] = torch.tensor(data_list) return collated_batch + class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): @@ -62,7 +63,7 @@ def setUpClass(cls): pin_memory=True, ) cls.dataloaders.append(dataloader) - + def count_non_zero_elements(self, tensor, threshold=1e-8): return torch.sum(torch.abs(tensor) > threshold).item() @@ -75,14 +76,15 @@ def test_make_stat_input(self): enable_element_completion=True, ) bias, _ = compute_output_stats(lst, ntypes=57) - energy = bias.get('energy') + energy = bias.get("energy") self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") non_zero_count = self.count_non_zero_elements(energy) self.assertEqual( - non_zero_count, - 6, - 
f"Expected exactly 7 non-zero elements, but got {non_zero_count}." + non_zero_count, + 6, + f"Expected exactly 7 non-zero elements, but got {non_zero_count}.", ) + def test_make_stat_input_nocomplete(self): lst = make_stat_input( datasets=self.datasets, @@ -92,14 +94,15 @@ def test_make_stat_input_nocomplete(self): enable_element_completion=False, ) bias, _ = compute_output_stats(lst, ntypes=57) - energy = bias.get('energy') + energy = bias.get("energy") self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") non_zero_count = self.count_non_zero_elements(energy) self.assertLess( - non_zero_count, - 6, - f"Expected fewer than 7 non-zero elements, but got {non_zero_count}." + non_zero_count, + 6, + f"Expected fewer than 7 non-zero elements, but got {non_zero_count}.", ) + if __name__ == "__main__": unittest.main() From 6713c1a7b5bdf17c5dba88ce341cefc43ef8b4d8 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 11:22:43 +0800 Subject: [PATCH 30/82] last --- deepmd/pt/utils/stat.py | 140 ++++++++++++++++++++------------------- deepmd/utils/argcheck.py | 2 +- 2 files changed, 74 insertions(+), 68 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index ec461f279a..45640703a1 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,14 +36,7 @@ log = logging.getLogger(__name__) - -def make_stat_input( - datasets, - dataloaders, - nbatches, - min_frames_per_element_forstat, - enable_element_completion=True, -): +def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): """Pack data for statistics. Args: @@ -65,13 +58,11 @@ def make_stat_input( if datasets[0].mixed_type: if enable_element_completion: log.info( - f"Element check enabled. " - f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." + f'Element check enabled. ' + f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' 
) else: - log.info( - "Element completion is disabled. Skipping missing element handling." - ) + log.info("Element completion is disabled. Skipping missing element handling.") def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -100,10 +91,7 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or ( - isinstance(sys_stat[key], list) - and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) - ): + elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -114,77 +102,95 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type: - if "atype" in sys_stat and isinstance(sys_stat["atype"], list): - collect_values = torch.unique( - torch.cat(sys_stat["atype"]).flatten(), sorted=True - ) - collect_elements.update(collect_values.tolist()) + if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): + collect_values = torch.unique(torch.cat(sys_stat['atype']).flatten(), sorted=True) + collect_elements.update(collect_values.tolist()) finalize_stats(sys_stat) lst.append(sys_stat) - # get frame index - if datasets[0].mixed_type: + #get frame index + if datasets[0].mixed_type and enable_element_completion: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): indices = data["indices"] + count = data["frames"] total_element_types.add(elem) if elem not in global_element_counts: - global_element_counts[elem] = {"frames": [], "indices": []} - global_element_counts[elem]["frames"].extend(indices) - if ( - len(global_element_counts[elem]["indices"]) - < min_frames_per_element_forstat - ): - global_element_counts[elem]["indices"].append( - {"sys_index": sys_index, 
"frames": indices} - ) - # Check whether the element used for statistics is complete in mixed_type + global_element_counts[elem] = {"count": 0, "indices": []} + if count > min_frames_per_element_forstat: + global_element_counts[elem]["count"] += min_frames_per_element_forstat + indices = indices[:min_frames_per_element_forstat] + global_element_counts[elem]["indices"].append({ + "sys_index": sys_index, + "frames": indices + }) + else: + global_element_counts[elem]["count"] += count + global_element_counts[elem]["indices"].append({ + "sys_index": sys_index, + "frames": indices + }) + else: + if global_element_counts[elem]["count"] >= min_frames_per_element_forstat: + pass + else: + global_element_counts[elem]["count"] += count + global_element_counts[elem]["indices"].append({ + "sys_index": sys_index, + "frames": indices + }) + for key, value in global_element_counts.items(): + print(f"{key}: {value}\n") if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): - indices_count = len(data["indices"]) + indices_count = data["count"] if indices_count < min_frames_per_element_forstat: log.warning( - f"The number of frames with element {elem} is {indices_count}, " - f"which is less than the required {min_frames_per_element_forstat}" + f'The number of frames with element {elem} is {indices_count}, ' + f'which is less than the required {min_frames_per_element_forstat}' ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = global_element_counts[miss].get("indices", []) + sys_indices = global_element_counts[miss].get('indices', []) + newele_counter = 0 for sys_info in sys_indices: - sys_index = sys_info["sys_index"] - frames = sys_info["frames"] + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] sys = datasets[sys_index] for frame in frames: - frame_data = sys.__getitem__(frame) - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if 
frame_data[dd] is None: - sys_stat_new[dd] = None - elif isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] - tensor_data = torch.from_numpy(frame_data[dd]) - tensor_data = tensor_data.unsqueeze(0) - sys_stat_new[dd].append(tensor_data) - elif isinstance(frame_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: - sys_stat_new[key] = None - elif isinstance(sys_stat_new[key][0], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) - lst.append(sys_stat_new) + newele_counter += 1 + if not newele_counter > min_frames_per_element_forstat: + frame_data = sys.__getitem__(frame) + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + tensor_data = torch.from_numpy(frame_data[dd]) + tensor_data = tensor_data.unsqueeze(0) + sys_stat_new[dd].append(tensor_data) + elif isinstance(frame_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] + else: + pass + for key in sys_stat_new: + if isinstance(sys_stat_new[key], np.float32): + pass + elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + sys_stat_new[key] = None + elif isinstance(sys_stat_new[key][0], torch.Tensor): + sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) + dict_to_device(sys_stat_new) + lst.append(sys_stat_new) + else: + break return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 05911fc317..8868c9e590 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2830,7 +2830,7 @@ def training_args( "min_frames_per_element_forstat", int, 
optional=True, - doc="The minimum number of frames per element used for statistics.", + doc="The minimum number of frames per element used for statistics when using the mixed type.", ), Argument( "enable_element_completion", From 1c15cf0e9699e7e583db8aa366ee83807ad9f3c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Jan 2025 03:24:13 +0000 Subject: [PATCH 31/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 78 ++++++++++++++++++++++++++--------------- 1 file changed, 49 insertions(+), 29 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 45640703a1..d0893947c9 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,7 +36,14 @@ log = logging.getLogger(__name__) -def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): + +def make_stat_input( + datasets, + dataloaders, + nbatches, + min_frames_per_element_forstat, + enable_element_completion=True, +): """Pack data for statistics. Args: @@ -58,11 +65,13 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors if datasets[0].mixed_type: if enable_element_completion: log.info( - f'Element check enabled. ' - f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' + f"Element check enabled. " + f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." ) else: - log.info("Element completion is disabled. Skipping missing element handling.") + log.info( + "Element completion is disabled. Skipping missing element handling." 
+ ) def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -91,7 +100,10 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -102,14 +114,16 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type: - if 'atype' in sys_stat and isinstance(sys_stat['atype'], list): - collect_values = torch.unique(torch.cat(sys_stat['atype']).flatten(), sorted=True) - collect_elements.update(collect_values.tolist()) + if "atype" in sys_stat and isinstance(sys_stat["atype"], list): + collect_values = torch.unique( + torch.cat(sys_stat["atype"]).flatten(), sorted=True + ) + collect_elements.update(collect_values.tolist()) finalize_stats(sys_stat) lst.append(sys_stat) - #get frame index + # get frame index if datasets[0].mixed_type and enable_element_completion: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): @@ -119,27 +133,29 @@ def finalize_stats(sys_stat): if elem not in global_element_counts: global_element_counts[elem] = {"count": 0, "indices": []} if count > min_frames_per_element_forstat: - global_element_counts[elem]["count"] += min_frames_per_element_forstat + global_element_counts[elem]["count"] += ( + min_frames_per_element_forstat + ) indices = indices[:min_frames_per_element_forstat] - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: 
global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: - if global_element_counts[elem]["count"] >= min_frames_per_element_forstat: + if ( + global_element_counts[elem]["count"] + >= min_frames_per_element_forstat + ): pass else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) for key, value in global_element_counts.items(): print(f"{key}: {value}\n") if datasets[0].mixed_type and enable_element_completion: @@ -147,16 +163,16 @@ def finalize_stats(sys_stat): indices_count = data["count"] if indices_count < min_frames_per_element_forstat: log.warning( - f'The number of frames with element {elem} is {indices_count}, ' - f'which is less than the required {min_frames_per_element_forstat}' + f"The number of frames with element {elem} is {indices_count}, " + f"which is less than the required {min_frames_per_element_forstat}" ) missing_elements = total_element_types - collect_elements for miss in missing_elements: - sys_indices = global_element_counts[miss].get('indices', []) + sys_indices = global_element_counts[miss].get("indices", []) newele_counter = 0 for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] sys = datasets[sys_index] for frame in frames: newele_counter += 1 @@ -181,7 +197,10 @@ def finalize_stats(sys_stat): for key in sys_stat_new: if isinstance(sys_stat_new[key], np.float32): pass - elif sys_stat_new[key] is None or sys_stat_new[key][0] is None: + elif ( + sys_stat_new[key] is None + or sys_stat_new[key][0] is None + ): sys_stat_new[key] = None elif 
isinstance(sys_stat_new[key][0], torch.Tensor): sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) @@ -191,6 +210,7 @@ def finalize_stats(sys_stat): break return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 33c716d0f6574bf578b60280ff49b25c8f157d9d Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 11:25:16 +0800 Subject: [PATCH 32/82] q --- deepmd/pt/utils/stat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 45640703a1..5556c1077b 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -140,8 +140,6 @@ def finalize_stats(sys_stat): "sys_index": sys_index, "frames": indices }) - for key, value in global_element_counts.items(): - print(f"{key}: {value}\n") if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] From 6bbced83dd2996e591c6f4854c11d288b0f6f3dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Jan 2025 03:27:17 +0000 Subject: [PATCH 33/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index c7ffd3a8cf..cf88659796 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -153,10 +153,9 @@ def finalize_stats(sys_stat): pass else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] From b462c9773044439576157c7e8171244dd3140dcc Mon Sep 17 
00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 13:34:34 +0800 Subject: [PATCH 34/82] ll --- deepmd/pt/utils/stat.py | 94 ++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 57 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index c7ffd3a8cf..eeac9b16a4 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,15 +36,9 @@ log = logging.getLogger(__name__) - -def make_stat_input( - datasets, - dataloaders, - nbatches, - min_frames_per_element_forstat, - enable_element_completion=True, -): +def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): """Pack data for statistics. + Element checking is only enabled with mixed_type. Args: - datasets: A list of datasets to analyze. @@ -59,19 +53,17 @@ def make_stat_input( """ lst = [] log.info(f"Packing data for statistics from {len(datasets)} systems") - collect_elements = set() total_element_types = set() global_element_counts = {} + collect_ele = defaultdict(int) if datasets[0].mixed_type: if enable_element_completion: log.info( - f"Element check enabled. " - f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." + f'Element check enabled. ' + f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' ) else: - log.info( - "Element completion is disabled. Skipping missing element handling." - ) + log.info("Element completion is disabled. 
Skipping missing element handling.") def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -100,10 +92,7 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or ( - isinstance(sys_stat[key], list) - and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) - ): + elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -113,17 +102,17 @@ def finalize_stats(sys_stat): sys_stat = {} with torch.device("cpu"): process_batches(dataloader, sys_stat) - if datasets[0].mixed_type: - if "atype" in sys_stat and isinstance(sys_stat["atype"], list): - collect_values = torch.unique( - torch.cat(sys_stat["atype"]).flatten(), sorted=True - ) - collect_elements.update(collect_values.tolist()) - + if datasets[0].mixed_type and enable_element_completion: + element_data = torch.cat(sys_stat['atype'], dim=0) + collect_values = torch.unique(element_data.flatten(), sorted=True) + for elem in collect_values.tolist(): + frames_with_elem = torch.any(element_data == elem, dim=1) + row_indices = torch.where(frames_with_elem)[0] + collect_ele[elem] += len(row_indices) finalize_stats(sys_stat) lst.append(sys_stat) - - # get frame index + + #get frame index if datasets[0].mixed_type and enable_element_completion: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): @@ -133,23 +122,20 @@ def finalize_stats(sys_stat): if elem not in global_element_counts: global_element_counts[elem] = {"count": 0, "indices": []} if count > min_frames_per_element_forstat: - global_element_counts[elem]["count"] += ( - min_frames_per_element_forstat - ) + global_element_counts[elem]["count"] += min_frames_per_element_forstat indices = 
indices[:min_frames_per_element_forstat] - global_element_counts[elem]["indices"].append( - {"sys_index": sys_index, "frames": indices} - ) + global_element_counts[elem]["indices"].append({ + "sys_index": sys_index, + "frames": indices + }) else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append( - {"sys_index": sys_index, "frames": indices} - ) + global_element_counts[elem]["indices"].append({ + "sys_index": sys_index, + "frames": indices + }) else: - if ( - global_element_counts[elem]["count"] - >= min_frames_per_element_forstat - ): + if global_element_counts[elem]["count"] >= min_frames_per_element_forstat: pass else: global_element_counts[elem]["count"] += count @@ -157,21 +143,26 @@ def finalize_stats(sys_stat): "sys_index": sys_index, "frames": indices }) + # Complement if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] if indices_count < min_frames_per_element_forstat: log.warning( - f"The number of frames with element {elem} is {indices_count}, " - f"which is less than the required {min_frames_per_element_forstat}" + f'The number of frames in your datasets with element {elem} is {indices_count}, ' + f'which is less than the required {min_frames_per_element_forstat}' ) + collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements + for ele, count in collect_ele.items(): + if count < min_frames_per_element_forstat: + missing_elements.add(ele) for miss in missing_elements: - sys_indices = global_element_counts[miss].get("indices", []) + sys_indices = global_element_counts[miss].get('indices', []) newele_counter = 0 for sys_info in sys_indices: - sys_index = sys_info["sys_index"] - frames = sys_info["frames"] + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] sys = datasets[sys_index] for frame in frames: newele_counter += 1 @@ -193,23 +184,12 @@ def finalize_stats(sys_stat): 
sys_stat_new[dd] = frame_data[dd] else: pass - for key in sys_stat_new: - if isinstance(sys_stat_new[key], np.float32): - pass - elif ( - sys_stat_new[key] is None - or sys_stat_new[key][0] is None - ): - sys_stat_new[key] = None - elif isinstance(sys_stat_new[key][0], torch.Tensor): - sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) + finalize_stats(sys_stat_new) lst.append(sys_stat_new) else: break return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 0c7baa0710ec74ab19d6a7378c875ff6ec53a74c Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 13:36:37 +0800 Subject: [PATCH 35/82] ll --- deepmd/pt/utils/stat.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 5f147c5cd7..6fe87e3256 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import numpy as np +import torch from collections import ( defaultdict, ) @@ -8,10 +10,6 @@ Optional, Union, ) - -import numpy as np -import torch - from deepmd.dpmodel.output_def import ( FittingOutputDef, ) @@ -143,6 +141,7 @@ def finalize_stats(sys_stat): "sys_index": sys_index, "frames": indices }) + # Complement if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] @@ -501,7 +500,6 @@ def compute_output_stats( std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} return bias_atom_e, std_atom_e - def compute_output_stats_global( sampled: list[dict], ntypes: int, From 28d94afc617c15a02875d8393d9a6082d7b329c4 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 14:00:32 +0800 Subject: [PATCH 36/82] ll --- deepmd/pt/utils/stat.py | 66 +++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/deepmd/pt/utils/stat.py 
b/deepmd/pt/utils/stat.py index 6fe87e3256..e868b95ea8 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -85,6 +85,36 @@ def process_batches(dataloader, sys_stat): else: pass + def process_with_oneframe(sys_indices, newele_counter): + for sys_info in sys_indices: + sys_index = sys_info['sys_index'] + frames = sys_info['frames'] + sys = datasets[sys_index] + for frame in frames: + newele_counter += 1 + if not newele_counter > min_frames_per_element_forstat: + frame_data = sys.__getitem__(frame) + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + tensor_data = torch.from_numpy(frame_data[dd]) + tensor_data = tensor_data.unsqueeze(0) + sys_stat_new[dd].append(tensor_data) + elif isinstance(frame_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] + else: + pass + finalize_stats(sys_stat_new) + lst.append(sys_stat_new) + else: + break + def finalize_stats(sys_stat): """Finalize statistics by concatenating tensors.""" for key in sys_stat: @@ -152,40 +182,18 @@ def finalize_stats(sys_stat): ) collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements + collect_miss_element = set() for ele, count in collect_ele.items(): if count < min_frames_per_element_forstat: + collect_miss_element.add(ele) missing_elements.add(ele) for miss in missing_elements: sys_indices = global_element_counts[miss].get('indices', []) - newele_counter = 0 - for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] - sys = datasets[sys_index] - for frame in frames: - newele_counter += 1 - if not newele_counter > min_frames_per_element_forstat: - frame_data = sys.__getitem__(frame) - sys_stat_new = {} - for dd in frame_data: - if dd == "type": - continue - if frame_data[dd] is None: - sys_stat_new[dd] = None - elif 
isinstance(frame_data[dd], np.ndarray): - if dd not in sys_stat_new: - sys_stat_new[dd] = [] - tensor_data = torch.from_numpy(frame_data[dd]) - tensor_data = tensor_data.unsqueeze(0) - sys_stat_new[dd].append(tensor_data) - elif isinstance(frame_data[dd], np.float32): - sys_stat_new[dd] = frame_data[dd] - else: - pass - finalize_stats(sys_stat_new) - lst.append(sys_stat_new) - else: - break + if miss in collect_miss_element: + newele_counter = collect_ele.get(miss, 0) + else: + newele_counter = 0 + process_with_oneframe(sys_indices,newele_counter) return lst def _restore_from_file( From 87dcd66c658707066ff0a5837bcc58c19cbc00b7 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 14:02:32 +0800 Subject: [PATCH 37/82] l --- source/tests/pt/test_make_stat_input.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index cf673637e3..2ccefb3f9e 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -37,6 +37,7 @@ class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): system_path = "mixed_type_data/sys.000000" + cls.real_ntypes = 6 cls.datasets = DeepmdDataSetForLoader(system=system_path) data_requirements = [ DataRequirementItem( @@ -81,8 +82,8 @@ def test_make_stat_input(self): non_zero_count = self.count_non_zero_elements(energy) self.assertEqual( non_zero_count, - 6, - f"Expected exactly 7 non-zero elements, but got {non_zero_count}.", + self.real_ntypes, + f"Expected exactly {self.real_ntypes} non-zero elements, but got {non_zero_count}.", ) def test_make_stat_input_nocomplete(self): @@ -99,8 +100,8 @@ def test_make_stat_input_nocomplete(self): non_zero_count = self.count_non_zero_elements(energy) self.assertLess( non_zero_count, - 6, - f"Expected fewer than 7 non-zero elements, but got {non_zero_count}.", + self.real_ntypes, + f"Expected fewer than {self.real_ntypes} non-zero elements, 
but got {non_zero_count}.", ) From 5d33060ee8af0d44ef20ce3180668497860152a4 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 14:03:45 +0800 Subject: [PATCH 38/82] ll --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index e868b95ea8..7b9f25c061 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -85,7 +85,7 @@ def process_batches(dataloader, sys_stat): else: pass - def process_with_oneframe(sys_indices, newele_counter): + def process_with_newframe(sys_indices, newele_counter): for sys_info in sys_indices: sys_index = sys_info['sys_index'] frames = sys_info['frames'] From 521e3a6f87dd8efe52bf64e4c0395fcf9e0b697d Mon Sep 17 00:00:00 2001 From: SumGuo Date: Sat, 4 Jan 2025 14:04:15 +0800 Subject: [PATCH 39/82] ll --- deepmd/pt/utils/stat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 7b9f25c061..f0af6e4d86 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -85,7 +85,7 @@ def process_batches(dataloader, sys_stat): else: pass - def process_with_newframe(sys_indices, newele_counter): + def process_with_new_frame(sys_indices, newele_counter): for sys_info in sys_indices: sys_index = sys_info['sys_index'] frames = sys_info['frames'] @@ -193,7 +193,7 @@ def finalize_stats(sys_stat): newele_counter = collect_ele.get(miss, 0) else: newele_counter = 0 - process_with_oneframe(sys_indices,newele_counter) + process_with_new_frame(sys_indices,newele_counter) return lst def _restore_from_file( From 0dabf77f92e409967133f71e7c341a2c9721c73a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Jan 2025 05:50:17 +0000 Subject: [PATCH 40/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 80 
+++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 9fb2db5c41..d246708e48 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import numpy as np -import torch from collections import ( defaultdict, ) @@ -11,13 +9,9 @@ Union, ) -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) - - import numpy as np import torch + from deepmd.pt.utils import ( AtomExcludeMask, ) @@ -40,7 +34,14 @@ log = logging.getLogger(__name__) -def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat, enable_element_completion=True): + +def make_stat_input( + datasets, + dataloaders, + nbatches, + min_frames_per_element_forstat, + enable_element_completion=True, +): """Pack data for statistics. Element checking is only enabled with mixed_type. @@ -63,11 +64,13 @@ def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_fors if datasets[0].mixed_type: if enable_element_completion: log.info( - f'Element check enabled. ' - f'Verifying if frames with elements meet the set of {min_frames_per_element_forstat}.' + f"Element check enabled. " + f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." ) else: - log.info("Element completion is disabled. Skipping missing element handling.") + log.info( + "Element completion is disabled. Skipping missing element handling." 
+ ) def process_batches(dataloader, sys_stat): """Process batches from a dataloader to collect statistics.""" @@ -93,8 +96,8 @@ def process_batches(dataloader, sys_stat): def process_with_new_frame(sys_indices, newele_counter): for sys_info in sys_indices: - sys_index = sys_info['sys_index'] - frames = sys_info['frames'] + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] sys = datasets[sys_index] for frame in frames: newele_counter += 1 @@ -126,7 +129,10 @@ def finalize_stats(sys_stat): for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or (isinstance(sys_stat[key], list) and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)): + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) @@ -137,7 +143,7 @@ def finalize_stats(sys_stat): with torch.device("cpu"): process_batches(dataloader, sys_stat) if datasets[0].mixed_type and enable_element_completion: - element_data = torch.cat(sys_stat['atype'], dim=0) + element_data = torch.cat(sys_stat["atype"], dim=0) collect_values = torch.unique(element_data.flatten(), sorted=True) for elem in collect_values.tolist(): frames_with_elem = torch.any(element_data == elem, dim=1) @@ -145,8 +151,8 @@ def finalize_stats(sys_stat): collect_ele[elem] += len(row_indices) finalize_stats(sys_stat) lst.append(sys_stat) - - #get frame index + + # get frame index if datasets[0].mixed_type and enable_element_completion: element_counts = dataset.get_frame_index() for elem, data in element_counts.items(): @@ -156,35 +162,37 @@ def finalize_stats(sys_stat): if elem not in global_element_counts: global_element_counts[elem] = {"count": 0, "indices": []} if count > min_frames_per_element_forstat: - global_element_counts[elem]["count"] += min_frames_per_element_forstat + 
global_element_counts[elem]["count"] += ( + min_frames_per_element_forstat + ) indices = indices[:min_frames_per_element_forstat] - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: - if global_element_counts[elem]["count"] >= min_frames_per_element_forstat: + if ( + global_element_counts[elem]["count"] + >= min_frames_per_element_forstat + ): pass else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({ - "sys_index": sys_index, - "frames": indices - }) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) # Complement if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] if indices_count < min_frames_per_element_forstat: log.warning( - f'The number of frames in your datasets with element {elem} is {indices_count}, ' - f'which is less than the required {min_frames_per_element_forstat}' + f"The number of frames in your datasets with element {elem} is {indices_count}, " + f"which is less than the required {min_frames_per_element_forstat}" ) collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements @@ -194,14 +202,15 @@ def finalize_stats(sys_stat): collect_miss_element.add(ele) missing_elements.add(ele) for miss in missing_elements: - sys_indices = global_element_counts[miss].get('indices', []) + sys_indices = global_element_counts[miss].get("indices", []) if miss in collect_miss_element: newele_counter = collect_ele.get(miss, 0) else: newele_counter = 0 - 
process_with_new_frame(sys_indices,newele_counter) + process_with_new_frame(sys_indices, newele_counter) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], @@ -523,6 +532,7 @@ def compute_output_stats( std_atom_e = {kk: to_torch_tensor(vv) for kk, vv in std_atom_e.items()} return bias_atom_e, std_atom_e + def compute_output_stats_global( sampled: list[dict], ntypes: int, From a23528c8c572e8d8577c01411673f8fac221c1c7 Mon Sep 17 00:00:00 2001 From: Yuliang Guo Date: Sun, 5 Jan 2025 13:51:47 +0800 Subject: [PATCH 41/82] Update stat.py fix conflict Signed-off-by: Yuliang Guo --- deepmd/pt/utils/stat.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index d246708e48..16c300b13a 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -8,10 +8,6 @@ Optional, Union, ) - -import numpy as np -import torch - from deepmd.pt.utils import ( AtomExcludeMask, ) From 0a97b547e8f689f2d5be69e194c7927f21444859 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Jan 2025 05:53:09 +0000 Subject: [PATCH 42/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 16c300b13a..26f160db00 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -8,6 +8,7 @@ Optional, Union, ) + from deepmd.pt.utils import ( AtomExcludeMask, ) From 49744eda93dc9cb1af0309ba4fd7015216403eb5 Mon Sep 17 00:00:00 2001 From: Yuliang Guo Date: Mon, 6 Jan 2025 09:33:46 +0800 Subject: [PATCH 43/82] Update deepmd/pt/utils/stat.py Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com> Signed-off-by: Yuliang Guo --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py 
b/deepmd/pt/utils/stat.py index 26f160db00..ceb790c792 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -98,7 +98,7 @@ def process_with_new_frame(sys_indices, newele_counter): sys = datasets[sys_index] for frame in frames: newele_counter += 1 - if not newele_counter > min_frames_per_element_forstat: + if newele_counter <= min_frames_per_element_forstat: frame_data = sys.__getitem__(frame) sys_stat_new = {} for dd in frame_data: From 556a68481439c93fdc6c84b04fcc6e9f5f492202 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 09:42:50 +0800 Subject: [PATCH 44/82] Simplify logic and remove "not" --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index f0af6e4d86..6406cd4361 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -92,7 +92,7 @@ def process_with_new_frame(sys_indices, newele_counter): sys = datasets[sys_index] for frame in frames: newele_counter += 1 - if not newele_counter > min_frames_per_element_forstat: + if newele_counter <= min_frames_per_element_forstat: frame_data = sys.__getitem__(frame) sys_stat_new = {} for dd in frame_data: From 27999af02ef59f6b4081eb3a134c7dd6a63d6121 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 09:48:50 +0800 Subject: [PATCH 45/82] check import --- deepmd/pt/utils/stat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index ceb790c792..3c47b1c268 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import numpy as np +import torch from collections import ( defaultdict, ) @@ -8,7 +10,9 @@ Optional, Union, ) - +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) from deepmd.pt.utils import ( AtomExcludeMask, ) From 83b7f1d894ade4452454f51e67ffee86160a71f0 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 01:50:16 +0000 Subject: [PATCH 46/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 3c47b1c268..455345c010 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import numpy as np -import torch from collections import ( defaultdict, ) @@ -10,9 +8,10 @@ Optional, Union, ) -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) + +import numpy as np +import torch + from deepmd.pt.utils import ( AtomExcludeMask, ) From 817d2ec2285d334fcfe212308802de119a27bcf3 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 10:00:46 +0800 Subject: [PATCH 47/82] Add assert to ensure that the new frame contains the required elements --- deepmd/pt/utils/stat.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 3c47b1c268..8ec8f7a5cb 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import numpy as np -import torch from collections import ( defaultdict, ) @@ -10,6 +8,8 @@ Optional, Union, ) +import numpy as np +import torch from deepmd.dpmodel.output_def import ( FittingOutputDef, ) @@ -35,7 +35,6 @@ log = logging.getLogger(__name__) - def make_stat_input( datasets, dataloaders, @@ -95,7 +94,7 @@ def process_batches(dataloader, sys_stat): else: pass - def process_with_new_frame(sys_indices, newele_counter): + def process_with_new_frame(sys_indices, newele_counter, miss): for sys_info in sys_indices: sys_index = sys_info["sys_index"] frames = sys_info["frames"] @@ -104,6 +103,7 @@ def 
process_with_new_frame(sys_indices, newele_counter): newele_counter += 1 if newele_counter <= min_frames_per_element_forstat: frame_data = sys.__getitem__(frame) + assert miss in frame_data['atype'], f"Missing element '{miss}' not found in frame data." sys_stat_new = {} for dd in frame_data: if dd == "type": @@ -208,10 +208,9 @@ def finalize_stats(sys_stat): newele_counter = collect_ele.get(miss, 0) else: newele_counter = 0 - process_with_new_frame(sys_indices, newele_counter) + process_with_new_frame(sys_indices, newele_counter, miss) return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 93a748f4da8d5583f15f724bfb3eb9a9a8522d64 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 10:03:07 +0800 Subject: [PATCH 48/82] check import --- deepmd/pt/utils/stat.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 173e8b4528..b73ecca229 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import numpy as np +import torch from collections import ( defaultdict, ) @@ -8,18 +10,11 @@ Optional, Union, ) -<<<<<<< HEAD import numpy as np import torch from deepmd.dpmodel.output_def import ( FittingOutputDef, ) -======= - -import numpy as np -import torch - ->>>>>>> 83b7f1d894ade4452454f51e67ffee86160a71f0 from deepmd.pt.utils import ( AtomExcludeMask, ) From 4a38f1deffb0456bd6735ffa532dab5dd6a544a1 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 10:04:13 +0800 Subject: [PATCH 49/82] check import --- deepmd/pt/utils/stat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index b73ecca229..8ec8f7a5cb 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import numpy as np -import torch from collections import ( 
defaultdict, ) From 78b2a10477d13858b4c2ba0f70151bdbb1165b44 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 02:04:31 +0000 Subject: [PATCH 50/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 8ec8f7a5cb..822ce1be04 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -8,11 +8,10 @@ Optional, Union, ) + import numpy as np import torch -from deepmd.dpmodel.output_def import ( - FittingOutputDef, -) + from deepmd.pt.utils import ( AtomExcludeMask, ) @@ -35,6 +34,7 @@ log = logging.getLogger(__name__) + def make_stat_input( datasets, dataloaders, @@ -103,7 +103,9 @@ def process_with_new_frame(sys_indices, newele_counter, miss): newele_counter += 1 if newele_counter <= min_frames_per_element_forstat: frame_data = sys.__getitem__(frame) - assert miss in frame_data['atype'], f"Missing element '{miss}' not found in frame data." + assert ( + miss in frame_data["atype"] + ), f"Missing element '{miss}' not found in frame data." 
sys_stat_new = {} for dd in frame_data: if dd == "type": @@ -211,6 +213,7 @@ def finalize_stats(sys_stat): process_with_new_frame(sys_indices, newele_counter, miss) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], From 6a5d169150249a8b5c311fe18b83aa83d0db2b86 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 11:18:54 +0800 Subject: [PATCH 51/82] check test.py --- source/tests/pt/test_make_stat_input.py | 44 +++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 2ccefb3f9e..3a095ee2b1 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -69,6 +69,7 @@ def count_non_zero_elements(self, tensor, threshold=1e-8): return torch.sum(torch.abs(tensor) > threshold).item() def test_make_stat_input(self): + #3 frames would be count lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -78,7 +79,7 @@ def test_make_stat_input(self): ) bias, _ = compute_output_stats(lst, ntypes=57) energy = bias.get("energy") - self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") + print(energy) non_zero_count = self.count_non_zero_elements(energy) self.assertEqual( non_zero_count, @@ -87,6 +88,9 @@ def test_make_stat_input(self): ) def test_make_stat_input_nocomplete(self): + #missing element:13,31,37 + #only one frame would be count + lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -96,7 +100,7 @@ def test_make_stat_input_nocomplete(self): ) bias, _ = compute_output_stats(lst, ntypes=57) energy = bias.get("energy") - self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") + print(energy) non_zero_count = self.count_non_zero_elements(energy) self.assertLess( non_zero_count, @@ -104,6 +108,42 @@ def test_make_stat_input_nocomplete(self): f"Expected fewer than {self.real_ntypes} 
non-zero elements, but got {non_zero_count}.", ) + def test_bias(self): + lst_ori = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + lst_all = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) + bias_all, _ = compute_output_stats(lst_all, ntypes=57) + energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() + energy_all = np.array(bias_all.get("energy").cpu()).flatten() + + for i, (e_ori, e_all) in enumerate(zip(energy_ori, energy_all)): + if e_all == 0: + self.assertEqual( + e_ori, + 0, + f"Index {i}: energy_all=0, but energy_ori={e_ori}" + ) + else: + if e_ori != 0: + diff = abs(e_ori - e_all) + rel_diff = diff / abs(e_ori) + self.assertTrue( + rel_diff < 0.4, + f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " + f"relative difference {rel_diff:.2%} is too large" + ) if __name__ == "__main__": unittest.main() From 26205d74c44201c56d39241b75e26ae381fcf67f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 03:20:28 +0000 Subject: [PATCH 52/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 3a095ee2b1..220e2fd75a 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -69,7 +69,7 @@ def count_non_zero_elements(self, tensor, threshold=1e-8): return torch.sum(torch.abs(tensor) > threshold).item() def test_make_stat_input(self): - #3 frames would be count + # 3 frames would be count lst = 
make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -88,8 +88,8 @@ def test_make_stat_input(self): ) def test_make_stat_input_nocomplete(self): - #missing element:13,31,37 - #only one frame would be count + # missing element:13,31,37 + # only one frame would be count lst = make_stat_input( datasets=self.datasets, @@ -127,13 +127,11 @@ def test_bias(self): bias_all, _ = compute_output_stats(lst_all, ntypes=57) energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() energy_all = np.array(bias_all.get("energy").cpu()).flatten() - + for i, (e_ori, e_all) in enumerate(zip(energy_ori, energy_all)): if e_all == 0: self.assertEqual( - e_ori, - 0, - f"Index {i}: energy_all=0, but energy_ori={e_ori}" + e_ori, 0, f"Index {i}: energy_all=0, but energy_ori={e_ori}" ) else: if e_ori != 0: @@ -142,8 +140,9 @@ def test_bias(self): self.assertTrue( rel_diff < 0.4, f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " - f"relative difference {rel_diff:.2%} is too large" + f"relative difference {rel_diff:.2%} is too large", ) + if __name__ == "__main__": unittest.main() From 3ccb4b9942b68fdc1da3b2eaa87e2d4a87f8d776 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 17:59:01 +0800 Subject: [PATCH 53/82] check ut --- source/tests/pt/test_make_stat_input.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 3a095ee2b1..3c9f93e99d 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -79,7 +79,6 @@ def test_make_stat_input(self): ) bias, _ = compute_output_stats(lst, ntypes=57) energy = bias.get("energy") - print(energy) non_zero_count = self.count_non_zero_elements(energy) self.assertEqual( non_zero_count, @@ -88,7 +87,7 @@ def test_make_stat_input(self): ) def test_make_stat_input_nocomplete(self): - #missing element:13,31,37 + #missing element:13,31,37 #only one frame would be count lst = 
make_stat_input( @@ -100,7 +99,6 @@ def test_make_stat_input_nocomplete(self): ) bias, _ = compute_output_stats(lst, ntypes=57) energy = bias.get("energy") - print(energy) non_zero_count = self.count_non_zero_elements(energy) self.assertLess( non_zero_count, From 87de0e8d780bc04615c20a2650b7d01c93736aa8 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 18:03:12 +0800 Subject: [PATCH 54/82] check ut --- source/tests/pt/test_make_stat_input.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index c546bb1017..96009ef7cf 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -87,13 +87,8 @@ def test_make_stat_input(self): ) def test_make_stat_input_nocomplete(self): -<<<<<<< HEAD #missing element:13,31,37 #only one frame would be count -======= - # missing element:13,31,37 - # only one frame would be count ->>>>>>> 26205d74c44201c56d39241b75e26ae381fcf67f lst = make_stat_input( datasets=self.datasets, @@ -146,6 +141,5 @@ def test_bias(self): f"relative difference {rel_diff:.2%} is too large", ) - if __name__ == "__main__": unittest.main() From 0939ef1e9b05c4ca8d0dcb994dfb467c61287e03 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 10:05:44 +0000 Subject: [PATCH 55/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 96009ef7cf..12dbe6d9e7 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -87,8 +87,8 @@ def test_make_stat_input(self): ) def test_make_stat_input_nocomplete(self): - #missing element:13,31,37 - #only one frame would be count + # 
missing element:13,31,37 + # only one frame would be count lst = make_stat_input( datasets=self.datasets, @@ -141,5 +141,6 @@ def test_bias(self): f"relative difference {rel_diff:.2%} is too large", ) + if __name__ == "__main__": unittest.main() From 7ec779fef7c894080e0ac31c478ba7c312d21730 Mon Sep 17 00:00:00 2001 From: Yuliang Guo Date: Tue, 7 Jan 2025 13:37:47 +0800 Subject: [PATCH 56/82] Update deepmd/utils/argcheck.py Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com> Signed-off-by: Yuliang Guo --- deepmd/utils/argcheck.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index ad1b8f8839..474a9e12a6 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2896,6 +2896,7 @@ def training_args( Argument( "min_frames_per_element_forstat", int, + default=10, optional=True, doc="The minimum number of frames per element used for statistics when using the mixed type.", ), From 24d1386195fefd7fdff316da33b54ad9b06e1924 Mon Sep 17 00:00:00 2001 From: Yuliang Guo Date: Tue, 7 Jan 2025 13:38:53 +0800 Subject: [PATCH 57/82] Update deepmd/utils/argcheck.py Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com> Signed-off-by: Yuliang Guo --- deepmd/utils/argcheck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 474a9e12a6..3a4091cd8a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2903,7 +2903,7 @@ def training_args( Argument( "enable_element_completion", bool, - optional=False, + optional=True, default=True, doc="Whether to check elements when using the mixed type", ), From 708bc78d0ad53a91a4cb2ae3fa7115fdc1c949ce Mon Sep 17 00:00:00 2001 From: SumGuo Date: Tue, 7 Jan 2025 14:24:27 +0800 Subject: [PATCH 58/82] check msi defalut value --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 
822ce1be04..eb021fe760 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -39,7 +39,7 @@ def make_stat_input( datasets, dataloaders, nbatches, - min_frames_per_element_forstat, + min_frames_per_element_forstat = 10, enable_element_completion=True, ): """Pack data for statistics. From 02f3f28bf41dce37fb0cc8e714b4c3bfea1593d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 06:25:58 +0000 Subject: [PATCH 59/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index eb021fe760..23ff776655 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -39,7 +39,7 @@ def make_stat_input( datasets, dataloaders, nbatches, - min_frames_per_element_forstat = 10, + min_frames_per_element_forstat=10, enable_element_completion=True, ): """Pack data for statistics. 
From d36a24ab4ab2408bf3c6335fa7a58d7cf57a79e6 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Tue, 7 Jan 2025 17:17:04 +0800 Subject: [PATCH 60/82] check ut cuda --- source/tests/pt/test_make_stat_input.py | 70 ++++++++++++------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 12dbe6d9e7..8196627685 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest - import numpy as np import torch from torch.utils.data import ( DataLoader, ) +from pathlib import ( + Path, +) from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, @@ -17,8 +19,7 @@ from deepmd.utils.data import ( DataRequirementItem, ) - - +torch.cuda.set_device(0) def collate_fn(batch): if isinstance(batch, dict): batch = [batch] @@ -36,34 +37,37 @@ def collate_fn(batch): class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): - system_path = "mixed_type_data/sys.000000" - cls.real_ntypes = 6 - cls.datasets = DeepmdDataSetForLoader(system=system_path) - data_requirements = [ - DataRequirementItem( - "energy", - ndof=1, - atomic=False, - ), - ] - cls.datasets.add_data_requirement(data_requirements) - cls.datasets = [cls.datasets] - weights = torch.tensor([0.1] * len(cls.datasets)) - sampler = torch.utils.data.WeightedRandomSampler( - weights, num_samples=len(weights), replacement=True - ) - cls.dataloaders = [] - for dataset in cls.datasets: - dataloader = DataLoader( - dataset, - sampler=sampler, - batch_size=1, - num_workers=0, - drop_last=False, - collate_fn=collate_fn, - pin_memory=True, + with torch.device("cpu"): + system_path = str(Path(__file__).parent / "mixed_type_data/sys.000000") + cls.real_ntypes = 6 + cls.datasets = DeepmdDataSetForLoader(system=system_path) + data_requirements = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + 
), + ] + cls.datasets.add_data_requirement(data_requirements) + cls.datasets = [cls.datasets] + weights_tensor = torch.tensor([0.1] * len(cls.datasets),dtype=torch.float64, device="cpu") + sampler = torch.utils.data.WeightedRandomSampler( + weights_tensor, + num_samples=len(cls.datasets), + replacement=True, ) - cls.dataloaders.append(dataloader) + cls.dataloaders = [] + for dataset in cls.datasets: + dataloader = DataLoader( + dataset, + sampler=sampler, + batch_size=1, + num_workers=0, + drop_last=False, + collate_fn=collate_fn, + pin_memory=False, + ) + cls.dataloaders.append(dataloader) def count_non_zero_elements(self, tensor, threshold=1e-8): return torch.sum(torch.abs(tensor) > threshold).item() @@ -139,8 +143,4 @@ def test_bias(self): rel_diff < 0.4, f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", - ) - - -if __name__ == "__main__": - unittest.main() + ) \ No newline at end of file From 050dbaf6b0abc7a2165c86bd0c3d88c585b68600 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 09:18:32 +0000 Subject: [PATCH 61/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 8196627685..2e309c6dc6 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +from pathlib import ( + Path, +) + import numpy as np import torch from torch.utils.data import ( DataLoader, ) -from pathlib import ( - Path, -) from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, @@ -19,7 +20,10 @@ from deepmd.utils.data import ( DataRequirementItem, ) + 
torch.cuda.set_device(0) + + def collate_fn(batch): if isinstance(batch, dict): batch = [batch] @@ -50,7 +54,9 @@ def setUpClass(cls): ] cls.datasets.add_data_requirement(data_requirements) cls.datasets = [cls.datasets] - weights_tensor = torch.tensor([0.1] * len(cls.datasets),dtype=torch.float64, device="cpu") + weights_tensor = torch.tensor( + [0.1] * len(cls.datasets), dtype=torch.float64, device="cpu" + ) sampler = torch.utils.data.WeightedRandomSampler( weights_tensor, num_samples=len(cls.datasets), @@ -143,4 +149,4 @@ def test_bias(self): rel_diff < 0.4, f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", - ) \ No newline at end of file + ) From b00f8de4be30670c1214c669e6a44356fb731f64 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Tue, 7 Jan 2025 18:42:21 +0800 Subject: [PATCH 62/82] check ut --- source/tests/pt/test_make_stat_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 8196627685..0a42f2b3b4 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -19,7 +19,7 @@ from deepmd.utils.data import ( DataRequirementItem, ) -torch.cuda.set_device(0) + def collate_fn(batch): if isinstance(batch, dict): batch = [batch] From 47fe45ba1cf76135352e6bdaf1c5c8dc1e58d15b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:45:08 +0000 Subject: [PATCH 63/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 68da84df6c..40a765f647 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -21,6 +21,7 @@ 
DataRequirementItem, ) + def collate_fn(batch): if isinstance(batch, dict): batch = [batch] From b9bdee50143cda7ba575cb42c3b9bcd2279d3672 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 9 Jan 2025 11:16:37 +0800 Subject: [PATCH 64/82] make truetype for more sys --- deepmd/pt/utils/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index d054fc1c17..d82c3d517f 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -61,14 +61,17 @@ def get_frame_index(self): """ element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) set_files = self._data_system.dirs + base_offset = 0 for set_file in set_files: element_data = self._data_system._load_type_mix(set_file) unique_elements = np.unique(element_data) for elem in unique_elements: frames_with_elem = np.any(element_data == elem, axis=1) row_indices = np.where(frames_with_elem)[0] + row_indices_global = np.where(frames_with_elem)[0] + base_offset element_counts[elem]["frames"] += len(row_indices) - element_counts[elem]["indices"].extend(row_indices.tolist()) + element_counts[elem]["indices"].extend(row_indices_global.tolist()) + base_offset += element_data.shape[0] element_counts = dict(element_counts) return element_counts From a30053f26f62754d5285b2c2554e6f910a76e083 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Jan 2025 03:18:07 +0000 Subject: [PATCH 65/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index d82c3d517f..c598c4cfe9 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -71,7 +71,7 @@ def get_frame_index(self): row_indices_global = np.where(frames_with_elem)[0] + base_offset 
element_counts[elem]["frames"] += len(row_indices) element_counts[elem]["indices"].extend(row_indices_global.tolist()) - base_offset += element_data.shape[0] + base_offset += element_data.shape[0] element_counts = dict(element_counts) return element_counts From cfbc88a9f27a90d1e28107af4452ba7f578235c4 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Thu, 9 Jan 2025 14:44:11 +0800 Subject: [PATCH 66/82] Add skip element check function to Chang bias --- deepmd/main.py | 5 +++++ deepmd/pt/entrypoints/main.py | 3 +++ deepmd/pt/utils/stat.py | 7 ++++++- .../mixed_type_data/sys.000000/set.001/box.npy | Bin 0 -> 308 bytes .../mixed_type_data/sys.000000/set.001/coord.npy | Bin 0 -> 548 bytes .../sys.000000/set.001/energy.npy | Bin 0 -> 148 bytes .../mixed_type_data/sys.000000/set.001/force.npy | Bin 0 -> 548 bytes .../sys.000000/set.001/real_atom_numbs.npy | Bin 0 -> 2408 bytes .../sys.000000/set.001/real_atom_types.npy | Bin 0 -> 408 bytes 9 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.001/box.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.001/coord.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.001/energy.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.001/force.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_numbs.npy create mode 100644 source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_types.npy diff --git a/deepmd/main.py b/deepmd/main.py index 097588ca0a..8eaabed44b 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -737,6 +737,11 @@ def main_parser() -> argparse.ArgumentParser: default=None, help="Model branch chosen for changing bias if multi-task model.", ) + parser_change_bias.add_argument( + "--skip-elementcheck", + action="store_false", + help="Enable this option to skip element checks if any error occurs while retrieving statistical data.", + ) # --version parser.add_argument( diff 
--git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index fd4be73e84..d379b1382b 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -386,6 +386,7 @@ def change_bias( numb_batch: int = 0, model_branch: Optional[str] = None, output: Optional[str] = None, + elem_check_stat: bool = True, ) -> None: if input_file.endswith(".pt"): old_state_dict = torch.load( @@ -472,6 +473,7 @@ def change_bias( data_single.systems, data_single.dataloaders, nbatches, + enable_element_completion = elem_check_stat, ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode @@ -555,6 +557,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: numb_batch=FLAGS.numb_batch, model_branch=FLAGS.model_branch, output=FLAGS.output, + elem_check_stat=FLAGS.skip_elementcheck, ) elif FLAGS.command == "compress": FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth")) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 23ff776655..0209ef13f5 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -105,7 +105,12 @@ def process_with_new_frame(sys_indices, newele_counter, miss): frame_data = sys.__getitem__(frame) assert ( miss in frame_data["atype"] - ), f"Missing element '{miss}' not found in frame data." + ), ( + f"Element check failed. " + f"If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. " + f"If you encountered this error during model training, set 'enable_element_completion' to False " + f"in the 'training' section of your input file." 
+ ) sys_stat_new = {} for dd in frame_data: if dd == "type": diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/box.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/box.npy new file mode 100644 index 0000000000000000000000000000000000000000..9a75fc05919f50d033ba8bb3b7a8b3c36c2dd8d0 GIT binary patch literal 308 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItrGWItsN4WCN~x5jIB#2xtIN2WO}|R_@ugA0ouSVCnS4(WIFLuEwOf x#bH~^QwNwlh<;JqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItoUHnmP)#3giN=I*~T}l}jf%ENQlN+$-nlARW5L;lE?Cj z7r%1opSH*0$o^u-BS1AjE*^C#tmtyw17u6Cd*+~9+Y4sjYFg>wGq>N-q?yH0g?Xlf z*gY@D7x$0t`?yzo|H&)!9RFeDwI%C7dIm}x$lXaDMB9Qz}V&vImyp5%}?^@RQH^m50Aw>cc& z`Au?o+=tuDTiu$D2~JlX8m6^4BKuuU*4V+hf04sBiF6ddGx5K4Fp#?KP=0Nuqmf#= zL!{O=2Zy6gj`0O94)c<)I^4e2>=-fK#o-*o0*9*D4#!xg&-=wMH9AC2EO8WFywskf zTg>5Cn6=}JB@gX)KL6!#MmNlnp?8@>@zHh%p{6Xyle`ZdKxzz68avwYeQ*G&VP4?p R*t)j$z_Izg4n?aH9094|%s>DD literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/energy.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/energy.npy new file mode 100644 index 0000000000000000000000000000000000000000..831de792c4b928dcb62d01fc38afb0717b4af2e7 GIT binary patch literal 148 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= rXCxM+0{I%II+{8PwF(pfuJ8qx2dy4_Il%Gf&w&rI?gt;Xm>dKE^MWUe literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/force.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/force.npy new file mode 100644 index 0000000000000000000000000000000000000000..3e112883296955ddfd5affcddc7af99fcc3532d6 GIT binary patch literal 548 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItoUHnmP)#3giN=&EMYHur3y|X_J_?OZ8sHw&14lJ@xL|wz=Ql?Kut< z@0Omn%T-o$Yip0c&AQ?o+q2BG%qO)3+05H_V|V-XkK3N7OxZo_x$oZHuRm<-ym)Dk ze3a_m(nbZF&3f&7{3KNN$bR;<6;XR(!`aigSG8Vk4|nu>+q+7pcK$C)ZBqV;+a|5B zwtjEDckeq9=6x;u4D4jwpY0T2Ibp?r?XQjJ+9x*NMJ#*f9-O!D5#uCV^~>w5r!6qE 
zednoe8(wwRwtiLJzH2Y{?5TOa&ek(xid|z4yKU0z!#17Y?Dws;T4TF+k>EbAz7*>f zdpdS8SWDROT$r&}Up8ppVeS=s<>U672=>fhC3^;TZOHfHq>yM!^nP8?pcZ literal 0 HcmV?d00001 diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_types.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_types.npy new file mode 100644 index 0000000000000000000000000000000000000000..1522e39c490791693b10d7c582af28bec1146a69 GIT binary patch literal 408 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC!@qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%IItu2RItsN4WCJc?1_;1Kvq9wr(P$MYpI91Z4zcQC>R|eCxeq1|Q-|(O M9;iEsrRAXN0YbbUga7~l literal 0 HcmV?d00001 From 73a20b0413bc5b870ee10d0eec8bd0235d0a7109 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Jan 2025 06:45:43 +0000 Subject: [PATCH 67/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/main.py | 2 +- deepmd/pt/entrypoints/main.py | 4 ++-- deepmd/pt/utils/stat.py | 12 +++++------- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index 8eaabed44b..80627f4bf0 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -739,7 +739,7 @@ def main_parser() -> argparse.ArgumentParser: ) parser_change_bias.add_argument( "--skip-elementcheck", - action="store_false", + action="store_false", help="Enable this option to skip element checks if any error occurs while retrieving statistical data.", ) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index d379b1382b..c0ba8d01be 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -386,7 +386,7 @@ def change_bias( numb_batch: int = 0, model_branch: Optional[str] = None, output: Optional[str] = None, - elem_check_stat: bool = True, + elem_check_stat: bool = True, ) -> None: if input_file.endswith(".pt"): old_state_dict = torch.load( @@ -473,7 +473,7 @@ def change_bias( data_single.systems, 
data_single.dataloaders, nbatches, - enable_element_completion = elem_check_stat, + enable_element_completion=elem_check_stat, ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 0209ef13f5..2d75665b35 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -103,13 +103,11 @@ def process_with_new_frame(sys_indices, newele_counter, miss): newele_counter += 1 if newele_counter <= min_frames_per_element_forstat: frame_data = sys.__getitem__(frame) - assert ( - miss in frame_data["atype"] - ), ( - f"Element check failed. " - f"If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. " - f"If you encountered this error during model training, set 'enable_element_completion' to False " - f"in the 'training' section of your input file." + assert miss in frame_data["atype"], ( + "Element check failed. " + "If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. " + "If you encountered this error during model training, set 'enable_element_completion' to False " + "in the 'training' section of your input file." 
) sys_stat_new = {} for dd in frame_data: From 0b29b05afd9e4cfdd880c768afb4cfeb9a7343b9 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 09:41:15 +0800 Subject: [PATCH 68/82] make changebias control minframes --- deepmd/main.py | 7 +++++++ deepmd/pt/entrypoints/main.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/deepmd/main.py b/deepmd/main.py index 8eaabed44b..245a60a2a7 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -742,6 +742,13 @@ def main_parser() -> argparse.ArgumentParser: action="store_false", help="Enable this option to skip element checks if any error occurs while retrieving statistical data.", ) + parser_change_bias.add_argument( + "-mf", + "--min-frames", + default=10, + type=int, + help="The minimum number of frames for each element used for statistics.", + ) # --version parser.add_argument( diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index d379b1382b..c7bc460ff2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -387,6 +387,7 @@ def change_bias( model_branch: Optional[str] = None, output: Optional[str] = None, elem_check_stat: bool = True, + min_frames : int = 10, ) -> None: if input_file.endswith(".pt"): old_state_dict = torch.load( @@ -473,6 +474,7 @@ def change_bias( data_single.systems, data_single.dataloaders, nbatches, + min_frames_per_element_forstat = min_frames, enable_element_completion = elem_check_stat, ) updated_model = training.model_change_out_bias( @@ -558,6 +560,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: model_branch=FLAGS.model_branch, output=FLAGS.output, elem_check_stat=FLAGS.skip_elementcheck, + min_frames=FLAGS.min_frames, ) elif FLAGS.command == "compress": FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth")) From 0400233f155f1955f4d5900fad12f541a9934911 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 09:43:12 +0800 Subject: [PATCH 69/82] check merge --- 
deepmd/pt/entrypoints/main.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index f66fa76267..9bb4dc4890 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -386,12 +386,8 @@ def change_bias( numb_batch: int = 0, model_branch: Optional[str] = None, output: Optional[str] = None, -<<<<<<< HEAD - elem_check_stat: bool = True, - min_frames : int = 10, -======= elem_check_stat: bool = True, ->>>>>>> 73a20b0413bc5b870ee10d0eec8bd0235d0a7109 + min_frames : int = 10, ) -> None: if input_file.endswith(".pt"): old_state_dict = torch.load( @@ -478,12 +474,8 @@ def change_bias( data_single.systems, data_single.dataloaders, nbatches, -<<<<<<< HEAD min_frames_per_element_forstat = min_frames, enable_element_completion = elem_check_stat, -======= - enable_element_completion=elem_check_stat, ->>>>>>> 73a20b0413bc5b870ee10d0eec8bd0235d0a7109 ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode From c05ffb16cb6a8acc51c6ff896de90dc455fbb942 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 01:44:37 +0000 Subject: [PATCH 70/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/entrypoints/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 9bb4dc4890..6edb2e90d8 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -387,7 +387,7 @@ def change_bias( model_branch: Optional[str] = None, output: Optional[str] = None, elem_check_stat: bool = True, - min_frames : int = 10, + min_frames: int = 10, ) -> None: if input_file.endswith(".pt"): old_state_dict = torch.load( @@ -474,8 +474,8 @@ def change_bias( data_single.systems, 
data_single.dataloaders, nbatches, - min_frames_per_element_forstat = min_frames, - enable_element_completion = elem_check_stat, + min_frames_per_element_forstat=min_frames, + enable_element_completion=elem_check_stat, ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode From 139f0378420def6f9d403326df64562609de71bc Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 11:31:18 +0800 Subject: [PATCH 71/82] improve ut with all frames --- source/tests/pt/test_make_stat_input.py | 71 +++++++++++++++++++------ 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 40a765f647..c3fcc5ebc1 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -23,17 +23,25 @@ def collate_fn(batch): + if isinstance(batch, dict): batch = [batch] - collated_batch = {} + + out = {} for key in batch[0].keys(): - data_list = [d[key] for d in batch] - if isinstance(data_list[0], np.ndarray): - data_np = np.stack(data_list) - collated_batch[key] = torch.from_numpy(data_np) + items = [sample[key] for sample in batch] + + if isinstance(items[0], torch.Tensor): + out[key] = torch.stack(items, dim=0) + elif isinstance(items[0], np.ndarray): + out[key] = torch.from_numpy(np.stack(items, axis=0)) else: - collated_batch[key] = torch.tensor(data_list) - return collated_batch + try: + out[key] = torch.tensor(items) + except Exception: + out[key] = items + + return out class TestMakeStatInput(unittest.TestCase): @@ -52,19 +60,10 @@ def setUpClass(cls): ] cls.datasets.add_data_requirement(data_requirements) cls.datasets = [cls.datasets] - weights_tensor = torch.tensor( - [0.1] * len(cls.datasets), dtype=torch.float64, device="cpu" - ) - sampler = torch.utils.data.WeightedRandomSampler( - weights_tensor, - num_samples=len(cls.datasets), - replacement=True, - ) cls.dataloaders = [] for dataset in 
cls.datasets: dataloader = DataLoader( dataset, - sampler=sampler, batch_size=1, num_workers=0, drop_last=False, @@ -129,6 +128,7 @@ def test_bias(self): min_frames_per_element_forstat=1, enable_element_completion=True, ) + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) bias_all, _ = compute_output_stats(lst_all, ntypes=57) energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() @@ -148,3 +148,42 @@ def test_bias(self): f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", ) + def test_with_nomissing(self): + lst_ori = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=10, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + for dct in lst_ori: + for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]: + if key in dct: + val = dct[key] + if val.numel() > 1: + dct[key] = val[0] + lst_new = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=10, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + for dct in lst_new: + for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]: + if key in dct: + val = dct[key] + if val.numel() > 1: + dct[key] = val[0] + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) + bias_new, _ = compute_output_stats(lst_new, ntypes=57) + energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() + energy_new = np.array(bias_new.get("energy").cpu()).flatten() + self.assertTrue( + np.array_equal(energy_ori, energy_new), + msg=f"energy_ori and energy_new are not exactly the same!\n" + f"energy_ori = {energy_ori}\nenergy_new = {energy_new}" + ) + +if __name__ == "__main__": + unittest.main() From 5e826bf8a23180243bd31eaefc2e920388063d43 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:32:43 +0000 Subject: [PATCH 72/82] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_make_stat_input.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index c3fcc5ebc1..e10c47945e 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -23,7 +23,6 @@ def collate_fn(batch): - if isinstance(batch, dict): batch = [batch] @@ -128,7 +127,7 @@ def test_bias(self): min_frames_per_element_forstat=1, enable_element_completion=True, ) - + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) bias_all, _ = compute_output_stats(lst_all, ntypes=57) energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() @@ -148,6 +147,7 @@ def test_bias(self): f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", ) + def test_with_nomissing(self): lst_ori = make_stat_input( datasets=self.datasets, @@ -182,8 +182,9 @@ def test_with_nomissing(self): self.assertTrue( np.array_equal(energy_ori, energy_new), msg=f"energy_ori and energy_new are not exactly the same!\n" - f"energy_ori = {energy_ori}\nenergy_new = {energy_new}" + f"energy_ori = {energy_ori}\nenergy_new = {energy_new}", ) + if __name__ == "__main__": unittest.main() From edf1d91e3c36d80ecf57ccbe38fdfc963dc87bb5 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 11:33:54 +0800 Subject: [PATCH 73/82] check ut --- source/tests/pt/test_make_stat_input.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index c3fcc5ebc1..b5d8c2ca07 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -183,7 +183,4 @@ def test_with_nomissing(self): np.array_equal(energy_ori, energy_new), msg=f"energy_ori and energy_new are not exactly the same!\n" f"energy_ori = 
{energy_ori}\nenergy_new = {energy_new}" - ) - -if __name__ == "__main__": - unittest.main() + ) \ No newline at end of file From eb9f068db6e22da349400bb8ed8c4b9cb220e010 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 11:35:00 +0800 Subject: [PATCH 74/82] check --- source/tests/pt/test_make_stat_input.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index ef7453fd9a..b5d8c2ca07 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -23,6 +23,7 @@ def collate_fn(batch): + if isinstance(batch, dict): batch = [batch] @@ -127,7 +128,7 @@ def test_bias(self): min_frames_per_element_forstat=1, enable_element_completion=True, ) - + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) bias_all, _ = compute_output_stats(lst_all, ntypes=57) energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() @@ -147,7 +148,6 @@ def test_bias(self): f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", ) - def test_with_nomissing(self): lst_ori = make_stat_input( datasets=self.datasets, @@ -182,14 +182,5 @@ def test_with_nomissing(self): self.assertTrue( np.array_equal(energy_ori, energy_new), msg=f"energy_ori and energy_new are not exactly the same!\n" -<<<<<<< HEAD f"energy_ori = {energy_ori}\nenergy_new = {energy_new}" - ) -======= - f"energy_ori = {energy_ori}\nenergy_new = {energy_new}", - ) - - -if __name__ == "__main__": - unittest.main() ->>>>>>> 5e826bf8a23180243bd31eaefc2e920388063d43 + ) \ No newline at end of file From 10ef768749e7ea3ca4bb38fd995793f131a8825a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:36:24 +0000 Subject: [PATCH 75/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
source/tests/pt/test_make_stat_input.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index b5d8c2ca07..bdc450512f 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -23,7 +23,6 @@ def collate_fn(batch): - if isinstance(batch, dict): batch = [batch] @@ -128,7 +127,7 @@ def test_bias(self): min_frames_per_element_forstat=1, enable_element_completion=True, ) - + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) bias_all, _ = compute_output_stats(lst_all, ntypes=57) energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() @@ -148,6 +147,7 @@ def test_bias(self): f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", ) + def test_with_nomissing(self): lst_ori = make_stat_input( datasets=self.datasets, @@ -182,5 +182,5 @@ def test_with_nomissing(self): self.assertTrue( np.array_equal(energy_ori, energy_new), msg=f"energy_ori and energy_new are not exactly the same!\n" - f"energy_ori = {energy_ori}\nenergy_new = {energy_new}" - ) \ No newline at end of file + f"energy_ori = {energy_ori}\nenergy_new = {energy_new}", + ) From c2dc7ef50a84d6b7979ab4d38ae5fa1f48c2bdea Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 13:44:15 +0800 Subject: [PATCH 76/82] check skip logic and def name --- deepmd/main.py | 2 +- deepmd/pt/entrypoints/main.py | 8 ++++---- deepmd/pt/utils/dataset.py | 2 +- deepmd/pt/utils/stat.py | 2 +- deepmd/utils/argcheck.py | 10 ++++++++-- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index 136cf70f8e..0e024b5011 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -739,7 +739,7 @@ def main_parser() -> argparse.ArgumentParser: ) parser_change_bias.add_argument( "--skip-elementcheck", - action="store_false", + action="store_true", help="Enable this option to skip element checks if 
any error occurs while retrieving statistical data.", ) parser_change_bias.add_argument( diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 6edb2e90d8..0fac599847 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -386,7 +386,7 @@ def change_bias( numb_batch: int = 0, model_branch: Optional[str] = None, output: Optional[str] = None, - elem_check_stat: bool = True, + skip_elem_check: bool = True, min_frames: int = 10, ) -> None: if input_file.endswith(".pt"): @@ -474,8 +474,8 @@ def change_bias( data_single.systems, data_single.dataloaders, nbatches, - min_frames_per_element_forstat=min_frames, - enable_element_completion=elem_check_stat, + min_frames_per_element_forstat = min_frames, + enable_element_completion = not skip_elem_check, ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode @@ -559,7 +559,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: numb_batch=FLAGS.numb_batch, model_branch=FLAGS.model_branch, output=FLAGS.output, - elem_check_stat=FLAGS.skip_elementcheck, + skip_elem_check=FLAGS.skip_elementcheck, min_frames=FLAGS.min_frames, ) elif FLAGS.command == "compress": diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index c598c4cfe9..267619b69d 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -43,7 +43,7 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data - def get_frame_index(self): + def get_frame_index_for_elements(self): """ Get the frame index and the number of frames with all the elements in the system. This function is only used in the mixed type. 
diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 2d75665b35..36a77c198c 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -160,7 +160,7 @@ def finalize_stats(sys_stat): # get frame index if datasets[0].mixed_type and enable_element_completion: - element_counts = dataset.get_frame_index() + element_counts = dataset.get_frame_index_for_elements() for elem, data in element_counts.items(): indices = data["indices"] count = data["frames"] diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 3a4091cd8a..79e763cf4e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2787,6 +2787,12 @@ def training_args( "If the gradient norm exceeds this value, it will be clipped to this limit. " "No gradient clipping will occur if set to 0." ) + doc_min_frames_per_element_forstat = ( + "The minimum number of frames per element used for statistics when using the mixed type." + ) + doc_enable_element_completion = ( + "Whether to check elements when using the mixed type" + ) doc_stat_file = ( "The file path for saving the data statistics results. 
" "If set, the results will be saved and directly loaded during the next training session, " @@ -2898,14 +2904,14 @@ def training_args( int, default=10, optional=True, - doc="The minimum number of frames per element used for statistics when using the mixed type.", + doc=doc_only_pt_supported + doc_min_frames_per_element_forstat, ), Argument( "enable_element_completion", bool, optional=True, default=True, - doc="Whether to check elements when using the mixed type", + doc=doc_only_pt_supported + doc_enable_element_completion, ), ] variants = [ From 87631657e2b5f060d1a98a5881f209b67306e740 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 05:45:41 +0000 Subject: [PATCH 77/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/entrypoints/main.py | 4 ++-- deepmd/utils/argcheck.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 0fac599847..5a0fd435f7 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -474,8 +474,8 @@ def change_bias( data_single.systems, data_single.dataloaders, nbatches, - min_frames_per_element_forstat = min_frames, - enable_element_completion = not skip_elem_check, + min_frames_per_element_forstat=min_frames, + enable_element_completion=not skip_elem_check, ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 79e763cf4e..122759541c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2787,9 +2787,7 @@ def training_args( "If the gradient norm exceeds this value, it will be clipped to this limit. " "No gradient clipping will occur if set to 0." 
) - doc_min_frames_per_element_forstat = ( - "The minimum number of frames per element used for statistics when using the mixed type." - ) + doc_min_frames_per_element_forstat = "The minimum number of frames per element used for statistics when using the mixed type." doc_enable_element_completion = ( "Whether to check elements when using the mixed type" ) From d5596bf9130c974e87a2e4d018be1153c437a50a Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 14:42:29 +0800 Subject: [PATCH 78/82] improve warning readable --- deepmd/pt/utils/dataset.py | 7 ++++++- deepmd/pt/utils/stat.py | 11 ++++++++--- deepmd/utils/data.py | 14 ++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 267619b69d..e1deaaed8b 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -62,9 +62,14 @@ def get_frame_index_for_elements(self): element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) set_files = self._data_system.dirs base_offset = 0 + global_type_name = {} for set_file in set_files: element_data = self._data_system._load_type_mix(set_file) unique_elements = np.unique(element_data) + type_name = self._data_system.build_reidx_to_name_map(element_data,set_file) + for new_idx, elem_name in type_name.items(): + if new_idx not in global_type_name: + global_type_name[new_idx] = elem_name for elem in unique_elements: frames_with_elem = np.any(element_data == elem, axis=1) row_indices = np.where(frames_with_elem)[0] @@ -73,7 +78,7 @@ def get_frame_index_for_elements(self): element_counts[elem]["indices"].extend(row_indices_global.tolist()) base_offset += element_data.shape[0] element_counts = dict(element_counts) - return element_counts + return element_counts, global_type_name def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py 
index 36a77c198c..7c0a8aa265 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -60,6 +60,7 @@ def make_stat_input( log.info(f"Packing data for statistics from {len(datasets)} systems") total_element_types = set() global_element_counts = {} + global_type_name = {} collect_ele = defaultdict(int) if datasets[0].mixed_type: if enable_element_completion: @@ -160,7 +161,10 @@ def finalize_stats(sys_stat): # get frame index if datasets[0].mixed_type and enable_element_completion: - element_counts = dataset.get_frame_index_for_elements() + element_counts, type_map = dataset.get_frame_index_for_elements() + for new_idx, elem_name in type_name.items(): + if new_idx not in global_type_name: + global_type_name[new_idx] = elem_name for elem, data in element_counts.items(): indices = data["indices"] count = data["frames"] @@ -195,10 +199,11 @@ def finalize_stats(sys_stat): if datasets[0].mixed_type and enable_element_completion: for elem, data in global_element_counts.items(): indices_count = data["count"] + element_name = global_type_name.get(elem, f"") if indices_count < min_frames_per_element_forstat: log.warning( - f"The number of frames in your datasets with element {elem} is {indices_count}, " - f"which is less than the required {min_frames_per_element_forstat}" + f"The number of frames in your datasets with element {element_name} is {indices_count}, " + f"which is less than the set {min_frames_per_element_forstat}" ) collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index fa01452bac..f5b9397df8 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -706,6 +706,20 @@ def _load_type_mix(self, set_name: DPPath): real_type = atom_type_mix_ return real_type + def build_reidx_to_name_map(self,typemix, set_name: DPPath): + type_map = self.type_map + type_path = set_name / "real_atom_types.npy" + real_type = 
type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms]) + type_map_array = np.array(type_map, dtype=object) + reidx_to_name = {} + N, M = real_type.shape + for i in range(N): + for j in range(M): + old_val = int(real_type[i, j]) + new_val = int(typemix[i, j]) + reidx_to_name[new_val] = type_map_array[old_val] + return reidx_to_name + def _make_idx_map(self, atom_type): natoms = atom_type.shape[0] idx = np.arange(natoms, dtype=np.int64) From 9f389ad951a5043a0e6dde52f007ea68da8f2e2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 06:44:23 +0000 Subject: [PATCH 79/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 6 ++++-- deepmd/utils/data.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index e1deaaed8b..590c64110d 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -62,11 +62,13 @@ def get_frame_index_for_elements(self): element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) set_files = self._data_system.dirs base_offset = 0 - global_type_name = {} + global_type_name = {} for set_file in set_files: element_data = self._data_system._load_type_mix(set_file) unique_elements = np.unique(element_data) - type_name = self._data_system.build_reidx_to_name_map(element_data,set_file) + type_name = self._data_system.build_reidx_to_name_map( + element_data, set_file + ) for new_idx, elem_name in type_name.items(): if new_idx not in global_type_name: global_type_name[new_idx] = elem_name diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index f5b9397df8..f2ad7061a9 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -706,7 +706,7 @@ def _load_type_mix(self, set_name: DPPath): real_type = atom_type_mix_ return real_type - def build_reidx_to_name_map(self,typemix, 
set_name: DPPath): + def build_reidx_to_name_map(self, typemix, set_name: DPPath): type_map = self.type_map type_path = set_name / "real_atom_types.npy" real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms]) @@ -715,8 +715,8 @@ def build_reidx_to_name_map(self,typemix, set_name: DPPath): N, M = real_type.shape for i in range(N): for j in range(M): - old_val = int(real_type[i, j]) - new_val = int(typemix[i, j]) + old_val = int(real_type[i, j]) + new_val = int(typemix[i, j]) reidx_to_name[new_val] = type_map_array[old_val] return reidx_to_name From 0c76ad945c8b5e684cdb38a81962d424301dc466 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 15:23:23 +0800 Subject: [PATCH 80/82] check args --- deepmd/pt/utils/dataset.py | 3 +++ deepmd/utils/data.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 590c64110d..4153d12519 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -46,6 +46,7 @@ def __getitem__(self, index): def get_frame_index_for_elements(self): """ Get the frame index and the number of frames with all the elements in the system. + Map the remapped atom_type_mix back to their element names in type_map, This function is only used in the mixed type. Returns @@ -58,6 +59,8 @@ def get_frame_index_for_elements(self): The total number of frames in which the element appears. - "indices": list of int A list of row indices where the element is found in the dataset. + global_type_name : dict + The key is the element index and the value is the element name. 
""" element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) set_files = self._data_system.dirs diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index f2ad7061a9..5fe92c617b 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -716,8 +716,8 @@ def build_reidx_to_name_map(self, typemix, set_name: DPPath): for i in range(N): for j in range(M): old_val = int(real_type[i, j]) - new_val = int(typemix[i, j]) - reidx_to_name[new_val] = type_map_array[old_val] + re_val = int(typemix[i, j]) + reidx_to_name[re_val] = type_map_array[old_val] return reidx_to_name def _make_idx_map(self, atom_type): From 58647f38f5b3228d42bea418ce7426df7da10542 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 07:24:52 +0000 Subject: [PATCH 81/82] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 4153d12519..a59f46be3f 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -46,7 +46,7 @@ def __getitem__(self, index): def get_frame_index_for_elements(self): """ Get the frame index and the number of frames with all the elements in the system. - Map the remapped atom_type_mix back to their element names in type_map, + Map the remapped atom_type_mix back to their element names in type_map, This function is only used in the mixed type. 
Returns From 4ce9cfba9493f7dc256091814650ac4d38d971da Mon Sep 17 00:00:00 2001 From: SumGuo Date: Fri, 10 Jan 2025 15:49:28 +0800 Subject: [PATCH 82/82] check stat.py --- deepmd/pt/utils/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 7c0a8aa265..7657e84d75 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -161,7 +161,7 @@ def finalize_stats(sys_stat): # get frame index if datasets[0].mixed_type and enable_element_completion: - element_counts, type_map = dataset.get_frame_index_for_elements() + element_counts, type_name = dataset.get_frame_index_for_elements() for new_idx, elem_name in type_name.items(): if new_idx not in global_type_name: global_type_name[new_idx] = elem_name