From c6cdfc39c4251fa3a094c52a8bb9ff274b28bc6d Mon Sep 17 00:00:00 2001
From: stefaanhess
Date: Fri, 16 Aug 2024 14:40:34 +0200
Subject: [PATCH] make estimate_atomrefs optional, docstrings, comments,
 assertions

---
 src/schnetpack/data/atoms.py          |  4 +++-
 src/schnetpack/data/datamodule.py     |  2 +-
 src/schnetpack/data/stats.py          | 25 ++++++++++++++++++-------
 src/schnetpack/transform/atomistic.py | 20 ++++++++++++++++----
 4 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/src/schnetpack/data/atoms.py b/src/schnetpack/data/atoms.py
index df3aacf4..75094337 100644
--- a/src/schnetpack/data/atoms.py
+++ b/src/schnetpack/data/atoms.py
@@ -359,7 +359,9 @@ def _get_properties(
         properties[structure.cell] = (
             torch.tensor(row["cell"][None].copy()) * self.distance_conversion
         )
-        properties[structure.pbc] = torch.tensor(row["pbc"].tolist())  # todo: check this RuntimeError: Could not infer dtype of numpy.bool
+        properties[structure.pbc] = torch.tensor(
+            row["pbc"].tolist()
+        )  # todo: check this RuntimeError: Could not infer dtype of numpy.bool
 
         return properties
 
diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py
index b941059f..28eeb50a 100644
--- a/src/schnetpack/data/datamodule.py
+++ b/src/schnetpack/data/datamodule.py
@@ -363,7 +363,7 @@ def get_stats(
 
     def get_atomrefs(
         self, property: str, divide_by_atoms: bool
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ):
         key = (property, divide_by_atoms)
         if key in self._stats:
diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py
index 6d6022b2..bb057e0c 100644
--- a/src/schnetpack/data/stats.py
+++ b/src/schnetpack/data/stats.py
@@ -23,13 +23,13 @@ def calculate_stats(
     .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
 
     Args:
-        dataset: atoms data set
+        dataloader: data loader
         divide_by_atoms: dict from property name to bool:
             If True, divide property by number of atoms before calculating statistics.
         atomref: reference values for single atoms to be removed
             before calculating stats
-
     Returns:
+        Mean and standard deviation over all samples
     """
 
     property_names = list(divide_by_atoms.keys())
@@ -41,7 +41,7 @@
     mean = torch.zeros_like(norm_mask)
     M2 = torch.zeros_like(norm_mask)
 
-    for props in tqdm(dataloader):
+    for props in tqdm(dataloader, "calculating statistics"):
         sample_values = []
         for p in property_names:
             val = props[p][None, :]
@@ -78,16 +78,27 @@
 
     return stats
 
 
-def estimate_atomrefs(loader, divide_by_atoms, z_max=100):
+def estimate_atomrefs(dataloader, divide_by_atoms, z_max=100):
+    """
+    Uses linear regression to estimate the elementwise biases (atomrefs).
+
+    Args:
+        dataloader: data loader
+        divide_by_atoms: dict from property name to bool:
+            If True, divide property by number of atoms before calculating statistics.
+
+    Returns:
+        Elementwise bias estimates over all samples
+
+    """
     property_names = list(divide_by_atoms.keys())
-    n_data = len(loader.dataset)
+    n_data = len(dataloader.dataset)
     all_properties = {pname: torch.zeros(n_data) for pname in property_names}
     all_atom_types = torch.zeros((n_data, z_max))
     data_counter = 0
 
-    # loop over all batches
-    for batch in tqdm(loader):
+    for batch in tqdm(dataloader, "estimating atomrefs"):
         # load data
         idx_m = batch[properties.idx_m]
         atomic_numbers = batch[properties.Z]
diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py
index 285774bb..9c794470 100644
--- a/src/schnetpack/transform/atomistic.py
+++ b/src/schnetpack/transform/atomistic.py
@@ -75,6 +75,7 @@ def __init__(
         zmax: int = 100,
         atomrefs: torch.Tensor = None,
         property_mean: torch.Tensor = None,
+        estimate_atomref: bool = False,
     ):
         """
         Args:
@@ -92,10 +93,14 @@
         self.remove_mean = remove_mean
         self.remove_atomrefs = remove_atomrefs
         self.is_extensive = is_extensive
+        self.estimate_atomref = estimate_atomref
 
         assert (
             remove_atomrefs or remove_mean
         ), "You should set at least one of `remove_mean` and `remove_atomrefs` to true!"
+        assert not (
+            estimate_atomref and atomrefs is not None
+        ), "You cannot set `atomrefs` and use `estimate_atomref=True`!"
 
         if atomrefs is not None:
             self._atomrefs_initialized = True
@@ -119,9 +124,10 @@ def datamodule(self, _datamodule):
         """
         Sets mean and atomref automatically when using PyTorchLightning integration.
         """
         if self.remove_atomrefs and not self._atomrefs_initialized:
-            atrefs = _datamodule.train_dataset.atomrefs
-            if atrefs[self._property] is None:
+            if self.estimate_atomref:
                 atrefs = _datamodule.get_atomrefs(self._property, self.is_extensive)
+            else:
+                atrefs = _datamodule.train_dataset.atomrefs
             self.atomref = atrefs[self._property].detach()
 
         if self.remove_mean and not self._mean_initialized:
@@ -228,6 +234,7 @@ def __init__(
         zmax: int = 100,
         atomrefs: torch.Tensor = None,
         property_mean: torch.Tensor = None,
+        estimate_atomref: bool = False,
     ):
         """
         Args:
@@ -246,10 +253,14 @@
         self.add_atomrefs = add_atomrefs
         self.is_extensive = is_extensive
         self._aggregation = "sum" if self.is_extensive else "mean"
+        self.estimate_atomref = estimate_atomref
 
         assert (
             add_mean or add_atomrefs
         ), "You should set at least one of `add_mean` and `add_atomrefs` to true!"
+        assert not (
+            estimate_atomref and atomrefs is not None
+        ), "You cannot set `atomrefs` and use `estimate_atomref=True`!"
 
         if atomrefs is not None:
             self._atomrefs_initialized = True
@@ -268,9 +279,10 @@
 
     def datamodule(self, _datamodule):
         if self.add_atomrefs and not self._atomrefs_initialized:
-            atrefs = _datamodule.train_dataset.atomrefs
-            if atrefs[self._property] is None:
+            if self.estimate_atomref:
                 atrefs = _datamodule.get_atomrefs(self._property, self.is_extensive)
+            else:
+                atrefs = _datamodule.train_dataset.atomrefs
             self.atomref = atrefs[self._property].detach()
 
         if self.add_mean and not self._mean_initialized:
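
Usage sketch (illustrative, not part of the diff): with this patch, setting
`estimate_atomref=True` on `RemoveOffsets`/`AddOffsets` makes the transform
estimate the single-atom reference values from the training data via
`estimate_atomrefs` (linear regression on per-element atom counts) instead of
reading them from the dataset, and the new assertion rejects passing explicit
`atomrefs` together with the flag. A minimal example, assuming schnetpack's
existing QM9 datamodule and transform APIs; the database path, property name,
and split sizes below are placeholders:

    import schnetpack as spk
    import schnetpack.transform as trn

    data = spk.datasets.QM9(
        "./qm9.db",  # placeholder database path
        batch_size=100,
        num_train=1000,
        num_val=100,
        transforms=[
            trn.ASENeighborList(cutoff=5.0),
            # New in this patch: estimate the atomrefs from the training
            # split instead of taking them from the dataset; combining this
            # flag with an explicit `atomrefs` tensor raises an AssertionError.
            trn.RemoveOffsets(
                "energy_U0",
                remove_atomrefs=True,
                estimate_atomref=True,
            ),
            trn.CastTo32(),
        ],
    )
    data.prepare_data()
    # During setup the transforms are attached to the datamodule, which
    # triggers the atomref estimation in the datamodule setter shown above.
    data.setup()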