Sh/elementwise statistics (#653)
* added: function for automatically estimating atomrefs

* make estimate_atomrefs optional, docstrings, comments, assertions

* undo temporary fix
stefaanhessmann authored Oct 25, 2024
1 parent b70361f commit 662ea1c
Showing 3 changed files with 118 additions and 19 deletions.
16 changes: 16 additions & 0 deletions src/schnetpack/data/datamodule.py
@@ -16,6 +16,7 @@
BaseAtomsData,
AtomsLoader,
calculate_stats,
estimate_atomrefs,
SplittingStrategy,
RandomSplit,
)
@@ -127,6 +128,7 @@ def __init__(
self.property_units = property_units
self.distance_unit = distance_unit
self._stats = {}
self._atomrefs = {}
self._is_setup = False
self.data_workdir = data_workdir
self.cleanup_workdir_stage = cleanup_workdir_stage
@@ -359,6 +361,20 @@ def get_stats(
self._stats[key] = stats
return stats

def get_atomrefs(
self, property: str, is_extensive: bool
) -> Dict[str, torch.Tensor]:
"""Estimate elementwise reference values on the training split and
cache them per (property, is_extensive) key."""
key = (property, is_extensive)
if key in self._atomrefs:
return {property: self._atomrefs[key]}

atomrefs = estimate_atomrefs(
self.train_dataloader(),
is_extensive={property: is_extensive},
)[property]
self._atomrefs[key] = atomrefs
return {property: atomrefs}

@property
def train_dataset(self) -> BaseAtomsData:
return self._train_dataset
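A minimal usage sketch for the new hook, assuming a standard QM9 setup (the database path, batch size, and split sizes below are placeholders, not from the commit):

```python
from schnetpack.datasets import QM9

# Placeholder setup: path and split sizes are illustrative only.
data = QM9("./qm9.db", batch_size=100, num_train=1000, num_val=100)
data.prepare_data()
data.setup()

# Returns {property_name: tensor of shape (z_max,)}, one bias per atomic number.
atomrefs = data.get_atomrefs(property=QM9.U0, is_extensive=True)
print(atomrefs[QM9.U0][1], atomrefs[QM9.U0][6])  # estimated H and C contributions
```

A second call with the same `(property, is_extensive)` pair is served from the `self._atomrefs` cache rather than re-running the regression.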
81 changes: 73 additions & 8 deletions src/schnetpack/data/stats.py
@@ -3,10 +3,10 @@
import torch
from tqdm import tqdm

import schnetpack.properties as structure
import schnetpack.properties as properties
from schnetpack.data import AtomsLoader

__all__ = ["calculate_stats"]
__all__ = ["calculate_stats", "estimate_atomrefs"]


def calculate_stats(
@@ -23,13 +23,13 @@ def calculate_stats(
.. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
Args:
dataset: atoms data set
dataloader: data loader
divide_by_atoms: dict from property name to bool:
If True, divide property by number of atoms before calculating statistics.
atomref: reference values for single atoms to be removed before calculating stats
Returns:
Mean and standard deviation over all samples
"""
property_names = list(divide_by_atoms.keys())
@@ -41,14 +41,14 @@
mean = torch.zeros_like(norm_mask)
M2 = torch.zeros_like(norm_mask)

for props in tqdm(dataloader):
for props in tqdm(dataloader, "calculating statistics"):
sample_values = []
for p in property_names:
val = props[p][None, :]
if atomref and p in atomref.keys():
ar = atomref[p]
ar = ar[props[structure.Z]]
idx_m = props[structure.idx_m]
ar = ar[props[properties.Z]]
idx_m = props[properties.idx_m]
tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device)
v0 = tmp.index_add(0, idx_m, ar)
val -= v0
@@ -59,7 +59,7 @@
batch_size = sample_values.shape[1]
new_count = count + batch_size

norm = norm_mask[:, None] * props[structure.n_atoms][None, :] + (
norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + (
1 - norm_mask[:, None]
)
sample_values /= norm
@@ -76,3 +76,68 @@
stddev = torch.sqrt(M2 / count)
stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)}
return stats
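The docstring's reference [h1] is the online variance algorithm; the per-batch update that `calculate_stats` performs reduces to this standalone sketch (a Chan-style merge over batches; helper names are mine, not the library's):

```python
import torch

def running_mean_std(batches):
    """Single-pass (Chan/Welford) mean and population std over 1-D tensor batches."""
    count, mean, m2 = 0, torch.tensor(0.0), torch.tensor(0.0)
    for x in batches:
        n = x.numel()
        delta = x.mean() - mean                    # batch mean vs. running mean
        mean = mean + delta * n / (count + n)      # merge the batch mean
        # merge the batch's sum of squared deviations (Chan's formula)
        m2 = m2 + ((x - x.mean()) ** 2).sum() + delta ** 2 * count * n / (count + n)
        count += n
    return mean, torch.sqrt(m2 / count)

x = torch.randn(1000)
mu, sigma = running_mean_std(x.split(128))
assert torch.allclose(mu, x.mean(), atol=1e-5)
assert torch.allclose(sigma, x.std(unbiased=False), atol=1e-5)
```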


def estimate_atomrefs(dataloader, is_extensive, z_max=100):
"""
Uses linear regression to estimate the elementwise biases (atomrefs).
Args:
dataloader: data loader
is_extensive: dict from property name to bool indicating whether the
property is extensive (scales with the number of atoms) or intensive.
z_max: maximum atomic number considered in the regression.
Returns:
Elementwise bias estimates over all samples
"""
property_names = list(is_extensive.keys())
n_data = len(dataloader.dataset)
all_properties = {pname: torch.zeros(n_data) for pname in property_names}
all_atom_types = torch.zeros((n_data, z_max))
data_counter = 0

# loop over all batches
for batch in tqdm(dataloader, "estimating atomrefs"):
# load data
idx_m = batch[properties.idx_m]
atomic_numbers = batch[properties.Z]

# get counts for atomic numbers
unique_ids = torch.unique(idx_m)
for i in unique_ids:
atomic_numbers_i = atomic_numbers[idx_m == i]
atom_types, atom_counts = torch.unique(atomic_numbers_i, return_counts=True)
# save atom counts and properties
for atom_type, atom_count in zip(atom_types, atom_counts):
all_atom_types[data_counter, atom_type] = atom_count
for pname in property_names:
property_value = batch[pname][i]
if not is_extensive[pname]:
# avoid in-place ops on views into the batch tensor
property_value = property_value * batch[properties.n_atoms][i]
all_properties[pname][data_counter] = property_value
data_counter += 1

# perform linear regression to get the elementwise energy contributions
existing_atom_types = torch.where(all_atom_types.sum(axis=0) != 0)[0]
X = torch.squeeze(all_atom_types[:, existing_atom_types])
w = dict()
for pname in property_names:
if is_extensive[pname]:
w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname]
else:
w[pname] = (
torch.linalg.inv(X.T @ X)
@ X.T
@ (all_properties[pname] / X.sum(axis=1))
)

# compute energy estimates
elementwise_contributions = {
pname: torch.zeros((z_max)) for pname in property_names
}
for pname in property_names:
for atom_type, weight in zip(existing_atom_types, w[pname]):
elementwise_contributions[pname][atom_type] = weight

return elementwise_contributions
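A design note on the regression above: it solves the normal equations explicitly with `torch.linalg.inv(X.T @ X)`. That is fine while the retained atom-type columns are linearly independent, but the Gram matrix becomes singular if, say, two elements always occur in a fixed ratio. A hedged alternative sketch (not part of the commit) using `torch.linalg.lstsq`:

```python
import torch

def solve_atomrefs(X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Least-squares fit of y ~ X @ w, where X[i, j] counts atoms of type j in sample i.

    Matches inv(X.T @ X) @ X.T @ y when X has full column rank, without
    forming and inverting the normal-equation matrix explicitly.
    """
    return torch.linalg.lstsq(X, y.unsqueeze(-1)).solution.squeeze(-1)

# Toy check: three "molecules" built from two element types with known biases.
X = torch.tensor([[2.0, 1.0], [1.0, 3.0], [4.0, 0.0]])  # atom-type counts
w_true = torch.tensor([-0.5, -2.0])                     # per-element contributions
assert torch.allclose(solve_atomrefs(X, X @ w_true), w_true, atol=1e-5)
```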
40 changes: 29 additions & 11 deletions src/schnetpack/transform/atomistic.py
@@ -75,6 +75,7 @@ def __init__(
zmax: int = 100,
atomrefs: torch.Tensor = None,
property_mean: torch.Tensor = None,
estimate_atomref: bool = False,
):
"""
Args:
@@ -92,10 +93,11 @@
self.remove_mean = remove_mean
self.remove_atomrefs = remove_atomrefs
self.is_extensive = is_extensive
self.estimate_atomref = estimate_atomref

assert (
remove_atomrefs or remove_mean
), "You should set at least one of `remove_mean` and `remove_atomrefs` to true!"
assert not (
estimate_atomref and atomrefs is not None
), "You can not set `atomrefs` and use `estimate_atomrefs=True!`"

if atomrefs is not None:
self._atomrefs_initialized = True
@@ -119,7 +121,12 @@ def datamodule(self, _datamodule):
Sets mean and atomref automatically when using PyTorchLightning integration.
"""
if self.remove_atomrefs and not self._atomrefs_initialized:
atrefs = _datamodule.train_dataset.atomrefs
if self.estimate_atomref:
atrefs = _datamodule.get_atomrefs(
property=self._property, is_extensive=self.is_extensive
)
else:
atrefs = _datamodule.train_dataset.atomrefs
self.atomref = atrefs[self._property].detach()

if self.remove_mean and not self._mean_initialized:
@@ -140,7 +147,10 @@ def forward(
)
inputs[self._property] -= mean
if self.remove_atomrefs:
inputs[self._property] -= torch.sum(self.atomref[inputs[structure.Z]])
atomref_bias = torch.sum(self.atomref[inputs[structure.Z]])
if not self.is_extensive:
atomref_bias /= inputs[structure.n_atoms].item()
inputs[self._property] -= atomref_bias

return inputs
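The intensive branch added to `forward` is easy to check with numbers: `self.atomref[inputs[structure.Z]]` contributes one bias per atom, so the summed offset scales with system size, and a per-atom target needs it renormalized by `n_atoms`. A toy illustration with invented values:

```python
import torch

atomref = torch.zeros(100)
atomref[1], atomref[8] = -0.5, -2.0   # invented per-element biases for H and O
Z = torch.tensor([8, 1, 1])           # one water molecule

bias = torch.sum(atomref[Z])          # -3.0, grows with the number of atoms
energy_per_atom = torch.tensor(-1.2)  # invented intensive target

# Extensive target: subtract the full bias. Intensive target: bias / n_atoms.
energy_per_atom -= bias / Z.numel()   # -1.2 - (-1.0) = -0.2
```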

@@ -230,6 +240,7 @@ def __init__(
zmax: int = 100,
atomrefs: torch.Tensor = None,
property_mean: torch.Tensor = None,
estimate_atomref: bool = False,
):
"""
Args:
@@ -247,10 +258,12 @@
self.add_mean = add_mean
self.add_atomrefs = add_atomrefs
self.is_extensive = is_extensive
self._aggregation = "sum" if self.is_extensive else "mean"
self.estimate_atomref = estimate_atomref

assert (
add_mean or add_atomrefs
), "You should set at least one of `add_mean` and `add_atomrefs` to true!"
assert not (
estimate_atomref and atomrefs is not None
), "You can not set `atomrefs` and use `estimate_atomrefs=True!`"

if atomrefs is not None:
self._atomrefs_initialized = True
@@ -267,13 +280,18 @@ def __init__(
self.register_buffer("atomref", atomrefs)
self.register_buffer("mean", property_mean)

def datamodule(self, value):
def datamodule(self, _datamodule):
if self.add_atomrefs and not self._atomrefs_initialized:
atrefs = value.train_dataset.atomrefs
if self.estimate_atomref:
atrefs = _datamodule.get_atomrefs(
property=self._property, is_extensive=self.is_extensive
)
else:
atrefs = _datamodule.train_dataset.atomrefs
self.atomref = atrefs[self._property].detach()

if self.add_mean and not self._mean_initialized:
stats = value.get_stats(
stats = _datamodule.get_stats(
self._property, self.is_extensive, self.add_atomrefs
)
self.mean = stats[0].detach()
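Putting the pieces together, a sketch of how the new flag might be wired into pre- and postprocessing (the property key and keyword values are illustrative; the surrounding task and trainer setup is omitted):

```python
import schnetpack.transform as trn

# Preprocessing: subtract estimated elementwise offsets from the target.
remove_offsets = trn.RemoveOffsets(
    "energy_U0",
    remove_atomrefs=True,
    is_extensive=True,
    estimate_atomref=True,  # use datamodule.get_atomrefs() instead of stored atomrefs
)

# Postprocessing: add the same offsets back onto model predictions.
add_offsets = trn.AddOffsets(
    "energy_U0",
    add_atomrefs=True,
    is_extensive=True,
    estimate_atomref=True,
)
```

With `estimate_atomref=True`, both transforms defer to `datamodule.get_atomrefs(...)` when the PyTorch Lightning integration calls their `datamodule` hook, so no tabulated atomrefs need to ship with the dataset.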
