diff --git a/src/schnetpack/data/datamodule.py b/src/schnetpack/data/datamodule.py
index 4484dc5e8..3ec83bfe1 100644
--- a/src/schnetpack/data/datamodule.py
+++ b/src/schnetpack/data/datamodule.py
@@ -16,6 +16,7 @@
     BaseAtomsData,
     AtomsLoader,
     calculate_stats,
+    estimate_atomrefs,
     SplittingStrategy,
     RandomSplit,
 )
@@ -127,6 +128,7 @@ def __init__(
         self.property_units = property_units
         self.distance_unit = distance_unit
         self._stats = {}
+        self._atomrefs = {}
         self._is_setup = False
         self.data_workdir = data_workdir
         self.cleanup_workdir_stage = cleanup_workdir_stage
@@ -359,6 +361,20 @@ def get_stats(
         self._stats[key] = stats
         return stats
 
+    def get_atomrefs(
+        self, property: str, is_extensive: bool
+    ) -> Dict[str, torch.Tensor]:
+        key = (property, is_extensive)
+        if key in self._atomrefs:
+            return {property: self._atomrefs[key]}
+
+        atomrefs = estimate_atomrefs(
+            self.train_dataloader(),
+            is_extensive={property: is_extensive},
+        )[property]
+        self._atomrefs[key] = atomrefs
+        return {property: atomrefs}
+
     @property
     def train_dataset(self) -> BaseAtomsData:
         return self._train_dataset
diff --git a/src/schnetpack/data/stats.py b/src/schnetpack/data/stats.py
index 22327d876..7276c8149 100644
--- a/src/schnetpack/data/stats.py
+++ b/src/schnetpack/data/stats.py
@@ -3,10 +3,10 @@
 import torch
 from tqdm import tqdm
 
-import schnetpack.properties as structure
+import schnetpack.properties as properties
 from schnetpack.data import AtomsLoader
 
-__all__ = ["calculate_stats"]
+__all__ = ["calculate_stats", "estimate_atomrefs"]
 
 
 def calculate_stats(
@@ -23,13 +23,13 @@
     .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
 
     Args:
-        dataset: atoms data set
+        dataloader: data loader
         divide_by_atoms: dict from property name to bool:
             If True, divide property by number of atoms before calculating statistics.
         atomref: reference values for single atoms to be removed before calculating
             stats
-
     Returns:
+        Mean and standard deviation over all samples
     """
     property_names = list(divide_by_atoms.keys())
@@ -41,14 +41,14 @@
     mean = torch.zeros_like(norm_mask)
     M2 = torch.zeros_like(norm_mask)
 
-    for props in tqdm(dataloader):
+    for props in tqdm(dataloader, "calculating statistics"):
         sample_values = []
         for p in property_names:
             val = props[p][None, :]
             if atomref and p in atomref.keys():
                 ar = atomref[p]
-                ar = ar[props[structure.Z]]
-                idx_m = props[structure.idx_m]
+                ar = ar[props[properties.Z]]
+                idx_m = props[properties.idx_m]
                 tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device)
                 v0 = tmp.index_add(0, idx_m, ar)
                 val -= v0
@@ -59,7 +59,7 @@
         batch_size = sample_values.shape[1]
         new_count = count + batch_size
 
-        norm = norm_mask[:, None] * props[structure.n_atoms][None, :] + (
+        norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + (
             1 - norm_mask[:, None]
         )
         sample_values /= norm
@@ -76,3 +76,68 @@
     stddev = torch.sqrt(M2 / count)
     stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)}
     return stats
+
+
+def estimate_atomrefs(dataloader, is_extensive, z_max=100):
+    """
+    Uses linear regression to estimate the elementwise biases (atomrefs).
+
+    Args:
+        dataloader: data loader
+        is_extensive: dict from property name to bool: If False, the property
+            is treated as intensive and multiplied by the number of atoms
+            before the regression.
+
+    Returns:
+        Elementwise bias estimates over all samples
+    """
+    property_names = list(is_extensive.keys())
+    n_data = len(dataloader.dataset)
+    all_properties = {pname: torch.zeros(n_data) for pname in property_names}
+    all_atom_types = torch.zeros((n_data, z_max))
+    data_counter = 0
+
+    # loop over all batches
+    for batch in tqdm(dataloader, "estimating atomrefs"):
+        # load data
+        idx_m = batch[properties.idx_m]
+        atomic_numbers = batch[properties.Z]
+
+        # get counts for atomic numbers
+        unique_ids = torch.unique(idx_m)
+        for i in unique_ids:
+            atomic_numbers_i = atomic_numbers[idx_m == i]
+            atom_types, atom_counts = torch.unique(atomic_numbers_i, return_counts=True)
+            # save atom counts and properties
+            for atom_type, atom_count in zip(atom_types, atom_counts):
+                all_atom_types[data_counter, atom_type] = atom_count
+            for pname in property_names:
+                property_value = batch[pname][i]
+                if not is_extensive[pname]:
+                    # out-of-place multiply; *= would mutate the batch tensor
+                    property_value = property_value * batch[properties.n_atoms][i]
+                all_properties[pname][data_counter] = property_value
+            data_counter += 1
+
+    # perform linear regression to get the elementwise energy contributions
+    existing_atom_types = torch.where(all_atom_types.sum(axis=0) != 0)[0]
+    X = torch.squeeze(all_atom_types[:, existing_atom_types])
+    w = dict()
+    for pname in property_names:
+        if is_extensive[pname]:
+            w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname]
+        else:
+            w[pname] = (
+                torch.linalg.inv(X.T @ X)
+                @ X.T
+                @ (all_properties[pname] / X.sum(axis=1))
+            )
+
+    # compute elementwise bias estimates
+    elementwise_contributions = {
+        pname: torch.zeros(z_max) for pname in property_names
+    }
+    for pname in property_names:
+        for atom_type, weight in zip(existing_atom_types, w[pname]):
+            elementwise_contributions[pname][atom_type] = weight
+
+    return elementwise_contributions
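`estimate_atomrefs` above fits per-element biases by ordinary least squares: property value ~= composition counts @ biases. A minimal self-contained sketch of that regression on made-up numbers (the molecules, the bias values, and the `torch.linalg.lstsq` variant are illustrative, not part of the patch):

    import torch

    # Toy composition matrix: four molecules over two elements, H and O.
    # Each row counts how often an element occurs in one structure.
    X = torch.tensor(
        [
            [2.0, 1.0],  # H2O
            [2.0, 2.0],  # H2O2
            [2.0, 0.0],  # H2
            [0.0, 2.0],  # O2
        ]
    )
    w_true = torch.tensor([-0.5, -75.0])  # assumed per-element biases
    y = X @ w_true                        # extensive property, e.g. energy

    # Normal equations, as in the patch ...
    w_normal = torch.linalg.inv(X.T @ X) @ X.T @ y
    # ... and a better-conditioned least-squares solve as an alternative.
    w_lstsq = torch.linalg.lstsq(X, y.unsqueeze(-1)).solution.squeeze(-1)

    assert torch.allclose(w_normal, w_true, atol=1e-3)
    assert torch.allclose(w_lstsq, w_true, atol=1e-3)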
diff --git a/src/schnetpack/transform/atomistic.py b/src/schnetpack/transform/atomistic.py
index b4e5cc1ff..9fd4c32c4 100644
--- a/src/schnetpack/transform/atomistic.py
+++ b/src/schnetpack/transform/atomistic.py
@@ -75,6 +75,7 @@
         zmax: int = 100,
         atomrefs: torch.Tensor = None,
         property_mean: torch.Tensor = None,
+        estimate_atomref: bool = False,
     ):
         """
         Args:
@@ -92,10 +93,11 @@
         self.remove_mean = remove_mean
         self.remove_atomrefs = remove_atomrefs
         self.is_extensive = is_extensive
+        self.estimate_atomref = estimate_atomref
 
-        assert (
-            remove_atomrefs or remove_mean
-        ), "You should set at least one of `remove_mean` and `remove_atomrefs` to true!"
+        assert not (
+            estimate_atomref and atomrefs is not None
+        ), "You cannot both pass `atomrefs` and set `estimate_atomref=True`!"
 
         if atomrefs is not None:
             self._atomrefs_initialized = True
@@ -119,7 +121,12 @@
         Sets mean and atomref automatically when using PyTorchLightning integration.
         """
         if self.remove_atomrefs and not self._atomrefs_initialized:
-            atrefs = _datamodule.train_dataset.atomrefs
+            if self.estimate_atomref:
+                atrefs = _datamodule.get_atomrefs(
+                    property=self._property, is_extensive=self.is_extensive
+                )
+            else:
+                atrefs = _datamodule.train_dataset.atomrefs
             self.atomref = atrefs[self._property].detach()
 
         if self.remove_mean and not self._mean_initialized:
@@ -140,7 +147,10 @@
             )
             inputs[self._property] -= mean
 
         if self.remove_atomrefs:
-            inputs[self._property] -= torch.sum(self.atomref[inputs[structure.Z]])
+            atomref_bias = torch.sum(self.atomref[inputs[structure.Z]])
+            if not self.is_extensive:
+                atomref_bias /= inputs[structure.n_atoms].item()
+            inputs[self._property] -= atomref_bias
 
         return inputs
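The `forward` change above sums the per-element references and, for intensive properties, rescales by the atom count. A worked check of that arithmetic on a toy single-structure input (all numbers illustrative):

    import torch

    # One H2O structure, as the transform sees a single sample.
    Z = torch.tensor([1, 1, 8])
    n_atoms = torch.tensor([3])

    atomref = torch.zeros(100)
    atomref[1] = -0.5   # hypothetical H reference
    atomref[8] = -75.0  # hypothetical O reference

    bias = torch.sum(atomref[Z])  # extensive offset: 2*(-0.5) - 75.0 = -76.0
    bias /= n_atoms.item()        # per-atom offset when is_extensive=False

    assert torch.isclose(bias, torch.tensor(-76.0 / 3))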
""" if self.remove_atomrefs and not self._atomrefs_initialized: - atrefs = _datamodule.train_dataset.atomrefs + if self.estimate_atomref: + atrefs = _datamodule.get_atomrefs( + property=self._property, is_extensive=self.is_extensive + ) + else: + atrefs = _datamodule.train_dataset.atomrefs self.atomref = atrefs[self._property].detach() if self.remove_mean and not self._mean_initialized: @@ -140,7 +147,10 @@ def forward( ) inputs[self._property] -= mean if self.remove_atomrefs: - inputs[self._property] -= torch.sum(self.atomref[inputs[structure.Z]]) + atomref_bias = torch.sum(self.atomref[inputs[structure.Z]]) + if not self.is_extensive: + atomref_bias /= inputs[structure.n_atoms].item() + inputs[self._property] -= atomref_bias return inputs @@ -230,6 +240,7 @@ def __init__( zmax: int = 100, atomrefs: torch.Tensor = None, property_mean: torch.Tensor = None, + estimate_atomref: bool = False, ): """ Args: @@ -247,10 +258,12 @@ def __init__( self.add_mean = add_mean self.add_atomrefs = add_atomrefs self.is_extensive = is_extensive + self._aggregation = "sum" if self.is_extensive else "mean" + self.estimate_atomref = estimate_atomref - assert ( - add_mean or add_atomrefs - ), "You should set at least one of `add_mean` and `add_atomrefs` to true!" + assert not ( + estimate_atomref and atomrefs is not None + ), "You can not set `atomrefs` and use `estimate_atomrefs=True!`" if atomrefs is not None: self._atomrefs_initialized = True @@ -267,13 +280,18 @@ def __init__( self.register_buffer("atomref", atomrefs) self.register_buffer("mean", property_mean) - def datamodule(self, value): + def datamodule(self, _datamodule): if self.add_atomrefs and not self._atomrefs_initialized: - atrefs = value.train_dataset.atomrefs + if self.estimate_atomref: + atrefs = _datamodule.get_atomrefs( + property=self._property, is_extensive=self.is_extensive + ) + else: + atrefs = _datamodule.train_dataset.atomrefs self.atomref = atrefs[self._property].detach() if self.add_mean and not self._mean_initialized: - stats = value.get_stats( + stats = _datamodule.get_stats( self._property, self.is_extensive, self.add_atomrefs ) self.mean = stats[0].detach()