make estimate_atomrefs optional, docstrings, comments, assertions
stefaanhessmann committed Aug 16, 2024
1 parent 6a2467e commit c6cdfc3
Showing 4 changed files with 37 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/schnetpack/data/atoms.py
@@ -359,7 +359,9 @@ def _get_properties(
         properties[structure.cell] = (
             torch.tensor(row["cell"][None].copy()) * self.distance_conversion
         )
-        properties[structure.pbc] = torch.tensor(row["pbc"].tolist())  # todo: check this RuntimeError: Could not infer dtype of numpy.bool
+        properties[structure.pbc] = torch.tensor(
+            row["pbc"].tolist()
+        )  # todo: check this RuntimeError: Could not infer dtype of numpy.bool

         return properties

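The `# todo` above refers to a dtype-inference failure that some numpy/torch version combinations raise when a numpy bool array is passed to `torch.tensor` directly. A minimal sketch of the failure mode and the workaround the diff keeps (the array value is made up for illustration):

import numpy as np
import torch

pbc = np.array([True, True, False])  # hypothetical "pbc" row value
# torch.tensor(pbc) can raise "RuntimeError: Could not infer dtype of numpy.bool"
# on some numpy/torch combinations; converting to a plain Python list first,
# as the code above does, sidesteps the dtype inference entirely.
pbc_tensor = torch.tensor(pbc.tolist())  # tensor([ True,  True, False])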
1 change: 0 additions & 1 deletion src/schnetpack/data/datamodule.py
@@ -363,7 +363,6 @@ def get_stats(

     def get_atomrefs(
         self, property: str, divide_by_atoms: bool
-
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         key = (property, divide_by_atoms)
         if key in self._stats:
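The surviving lines of `get_atomrefs` show its caching contract: each `(property, divide_by_atoms)` pair is computed once and stored in `self._stats`. A standalone sketch of that memoization pattern (the `_compute` helper is a hypothetical stand-in for the body hidden by the fold):

from typing import Dict, Tuple
import torch

_stats: Dict[Tuple[str, bool], Tuple[torch.Tensor, torch.Tensor]] = {}

def _compute(property: str, divide_by_atoms: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    # hypothetical stand-in for the folded statistics computation
    return torch.zeros(1), torch.ones(1)

def get_atomrefs(property: str, divide_by_atoms: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    key = (property, divide_by_atoms)
    if key in _stats:        # hit: reuse the stored result
        return _stats[key]
    result = _compute(property, divide_by_atoms)
    _stats[key] = result     # miss: compute once and cache
    return result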
25 changes: 18 additions & 7 deletions src/schnetpack/data/stats.py
@@ -23,13 +23,13 @@ def calculate_stats(
     .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

     Args:
-        dataset: atoms data set
+        dataloader: data loader
         divide_by_atoms: dict from property name to bool:
             If True, divide property by number of atoms before calculating statistics.
         atomref: reference values for single atoms to be removed before calculating stats

     Returns:
         Mean and standard deviation over all samples
     """
     property_names = list(divide_by_atoms.keys())
@@ -41,7 +41,7 @@
     mean = torch.zeros_like(norm_mask)
     M2 = torch.zeros_like(norm_mask)

-    for props in tqdm(dataloader):
+    for props in tqdm(dataloader, "calculating statistics"):
         sample_values = []
         for p in property_names:
             val = props[p][None, :]
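The `mean`/`M2` accumulators and the [h1] reference indicate Welford-style online statistics. A self-contained sketch of the update rule, with scalar toy samples standing in for the per-property batches of the actual implementation:

import torch

# Welford's online algorithm (see the [h1] Wikipedia link above): one pass,
# numerically stable running mean and M2 (sum of squared deviations).
count = 0
mean = torch.zeros(1)
M2 = torch.zeros(1)
for batch in [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]:  # toy data
    for x in batch:
        count += 1
        delta = x - mean
        mean += delta / count
        M2 += delta * (x - mean)
variance = M2 / count     # population variance
stddev = variance.sqrt()  # here: mean == 2.5, stddev ≈ 1.118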
@@ -78,16 +78,27 @@
     return stats


-def estimate_atomrefs(loader, divide_by_atoms, z_max=100):
+def estimate_atomrefs(dataloader, divide_by_atoms, z_max=100):
+    """
+    Uses linear regression to estimate the elementwise biases (atomrefs).
+
+    Args:
+        dataloader: data loader
+        divide_by_atoms: dict from property name to bool:
+            If True, divide property by number of atoms before calculating statistics.
+
+    Returns:
+        Elementwise bias estimates over all samples
+    """
     property_names = list(divide_by_atoms.keys())
-    n_data = len(loader.dataset)
+    n_data = len(dataloader.dataset)
     all_properties = {pname: torch.zeros(n_data) for pname in property_names}
     all_atom_types = torch.zeros((n_data, z_max))
     data_counter = 0

     # loop over all batches
-    for batch in tqdm(loader):
+    for batch in tqdm(dataloader, "estimating atomrefs"):
         # load data
         idx_m = batch[properties.idx_m]
         atomic_numbers = batch[properties.Z]
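Per its new docstring, `estimate_atomrefs` fits one bias per element to the per-structure property values. A minimal sketch of the underlying least-squares problem (the counts matrix and energies below are made-up toy data; the real function accumulates `all_atom_types` from the batches shown above):

import torch

# Each row counts how many atoms of each element a structure contains, with
# columns indexed by atomic number; solving counts @ atomrefs ≈ y in the
# least-squares sense yields one bias per element.
n_data, z_max = 4, 10                      # toy sizes (real default: z_max=100)
counts = torch.zeros(n_data, z_max)
counts[0, 1], counts[0, 8] = 2, 1          # H2O: 2 H, 1 O  (made-up rows)
counts[1, 1], counts[1, 8] = 2, 1
counts[2, 1], counts[2, 6] = 4, 1          # CH4: 4 H, 1 C
counts[3, 1], counts[3, 6] = 4, 1
y = torch.tensor([-76.4, -76.4, -40.5, -40.5])  # illustrative energies
atomrefs = torch.linalg.lstsq(counts, y[:, None]).solution.squeeze(-1)
# atomrefs[z] is the estimated elementwise bias for atomic number z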
20 changes: 16 additions & 4 deletions src/schnetpack/transform/atomistic.py
@@ -75,6 +75,7 @@ def __init__(
         zmax: int = 100,
         atomrefs: torch.Tensor = None,
         property_mean: torch.Tensor = None,
+        estimate_atomref: bool = False,
     ):
         """
         Args:
@@ -92,10 +93,14 @@
         self.remove_mean = remove_mean
         self.remove_atomrefs = remove_atomrefs
         self.is_extensive = is_extensive
+        self.estimate_atomref = estimate_atomref

         assert (
             remove_atomrefs or remove_mean
         ), "You should set at least one of `remove_mean` and `remove_atomrefs` to true!"
+        assert not (
+            estimate_atomref and atomrefs is not None
+        ), "You cannot set `atomrefs` and use `estimate_atomref=True`!"

         if atomrefs is not None:
             self._atomrefs_initialized = True
@@ -119,9 +124,10 @@ def datamodule(self, _datamodule):
         Sets mean and atomref automatically when using PyTorchLightning integration.
         """
         if self.remove_atomrefs and not self._atomrefs_initialized:
-            atrefs = _datamodule.train_dataset.atomrefs
-            if atrefs[self._property] is None:
+            if self.estimate_atomref:
                 atrefs = _datamodule.get_atomrefs(self._property, self.is_extensive)
+            else:
+                atrefs = _datamodule.train_dataset.atomrefs
             self.atomref = atrefs[self._property].detach()

         if self.remove_mean and not self._mean_initialized:
@@ -228,6 +234,7 @@ def __init__(
         zmax: int = 100,
         atomrefs: torch.Tensor = None,
         property_mean: torch.Tensor = None,
+        estimate_atomref: bool = False,
     ):
         """
         Args:
@@ -246,10 +253,14 @@
         self.add_atomrefs = add_atomrefs
         self.is_extensive = is_extensive
         self._aggregation = "sum" if self.is_extensive else "mean"
+        self.estimate_atomref = estimate_atomref

         assert (
             add_mean or add_atomrefs
         ), "You should set at least one of `add_mean` and `add_atomrefs` to true!"
+        assert not (
+            estimate_atomref and atomrefs is not None
+        ), "You cannot set `atomrefs` and use `estimate_atomref=True`!"

         if atomrefs is not None:
             self._atomrefs_initialized = True
@@ -268,9 +279,10 @@

     def datamodule(self, _datamodule):
         if self.add_atomrefs and not self._atomrefs_initialized:
-            atrefs = _datamodule.train_dataset.atomrefs
-            if atrefs[self._property] is None:
+            if self.estimate_atomref:
                 atrefs = _datamodule.get_atomrefs(self._property, self.is_extensive)
+            else:
+                atrefs = _datamodule.train_dataset.atomrefs
             self.atomref = atrefs[self._property].detach()

         if self.add_mean and not self._mean_initialized:
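Taken together, the atomistic.py changes make atomref estimation opt-in rather than an automatic fallback when the dataset provides no atomrefs. A hedged usage sketch, assuming the first modified class is schnetpack's RemoveOffsets transform (the class names sit outside the visible hunks):

import schnetpack.transform as trn

# With estimate_atomref=True (and no precomputed atomrefs), the transform asks
# the datamodule for regression-estimated atomrefs via get_atomrefs(...);
# with the default False it reads train_dataset.atomrefs as before.
remove_offsets = trn.RemoveOffsets(
    "energy",               # property to shift
    remove_mean=True,
    remove_atomrefs=True,
    estimate_atomref=True,  # new flag introduced by this commit
)
# Passing atomrefs=... together with estimate_atomref=True now fails the new
# assertion instead of silently preferring one source over the other.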
