Skip to content

Commit

Permalink
added: function for automatically estimating atomrefs
Browse files Browse the repository at this point in the history
  • Loading branch information
stefaanhessmann committed Aug 15, 2024
1 parent 643c9a1 commit 6a2467e
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/schnetpack/data/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _get_properties(
properties[structure.cell] = (
torch.tensor(row["cell"][None].copy()) * self.distance_conversion
)
properties[structure.pbc] = torch.tensor(row["pbc"])
properties[structure.pbc] = torch.tensor(row["pbc"].tolist()) # todo: check this RuntimeError: Could not infer dtype of numpy.bool

return properties

Expand Down
17 changes: 17 additions & 0 deletions src/schnetpack/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BaseAtomsData,
AtomsLoader,
calculate_stats,
estimate_atomrefs,
SplittingStrategy,
RandomSplit,
)
Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(
self.property_units = property_units
self.distance_unit = distance_unit
self._stats = {}
self._atomrefs = {}
self._is_setup = False
self.data_workdir = data_workdir
self.cleanup_workdir_stage = cleanup_workdir_stage
Expand Down Expand Up @@ -359,6 +361,21 @@ def get_stats(
self._stats[key] = stats
return stats

def get_atomrefs(
    self, property: str, divide_by_atoms: bool
) -> "dict[str, torch.Tensor]":
    """
    Return estimated atom reference values for ``property``.

    The values are estimated from the training data on first request and
    cached afterwards, so repeated calls with the same arguments do not
    iterate over the training set again.

    Args:
        property: name of the property the atomrefs are estimated for.
        divide_by_atoms: flag forwarded to ``estimate_atomrefs`` for this
            property; it is also part of the cache key.

    Returns:
        Dictionary with the single key ``property`` mapping to the estimated
        atomrefs. A mapping is returned (rather than the bare tensor) because
        callers look the result up by property name, the same way they access
        ``train_dataset.atomrefs``.
    """
    key = (property, divide_by_atoms)

    # Use the dedicated atomref cache. ``self._stats`` is filled by
    # ``get_stats`` and must not be consulted here, otherwise the cache is
    # never hit (or, worse, unrelated statistics are returned).
    if key not in self._atomrefs:
        self._atomrefs[key] = estimate_atomrefs(
            self.train_dataloader(),
            divide_by_atoms={property: divide_by_atoms},
        )[property]

    return {property: self._atomrefs[key]}

@property
def train_dataset(self) -> BaseAtomsData:
return self._train_dataset
Expand Down
51 changes: 46 additions & 5 deletions src/schnetpack/data/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from tqdm import tqdm

import schnetpack.properties as structure
import schnetpack.properties as properties
from schnetpack.data import AtomsLoader

__all__ = ["calculate_stats"]
__all__ = ["calculate_stats", "estimate_atomrefs"]


def calculate_stats(
Expand Down Expand Up @@ -47,8 +47,8 @@ def calculate_stats(
val = props[p][None, :]
if atomref and p in atomref.keys():
ar = atomref[p]
ar = ar[props[structure.Z]]
idx_m = props[structure.idx_m]
ar = ar[props[properties.Z]]
idx_m = props[properties.idx_m]
tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device)
v0 = tmp.index_add(0, idx_m, ar)
val -= v0
Expand All @@ -59,7 +59,7 @@ def calculate_stats(
batch_size = sample_values.shape[1]
new_count = count + batch_size

norm = norm_mask[:, None] * props[structure.n_atoms][None, :] + (
norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + (
1 - norm_mask[:, None]
)
sample_values /= norm
Expand All @@ -76,3 +76,44 @@ def calculate_stats(
stddev = torch.sqrt(M2 / count)
stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)}
return stats


def estimate_atomrefs(loader, divide_by_atoms, z_max=100):
    """
    Estimate elementwise reference values (atomrefs) by linear regression.

    For every structure yielded by ``loader`` the number of atoms per atomic
    number is counted. Each requested property is then fitted as a linear
    function of these counts; the fitted coefficient of an element is its
    atomref.

    Args:
        loader: iterable over batches that also provides ``loader.dataset``
            with a length. Every batch is indexed with ``properties.idx_m``,
            ``properties.Z`` and the requested property names.
        divide_by_atoms: dict whose keys are the names of the properties to
            fit. NOTE(review): the boolean values are currently not used,
            every property is fitted against the raw atom counts -- confirm
            how properties that are not summed over atoms should be treated.
        z_max: length of the returned atomref vectors; atomic numbers have to
            be smaller than this value.

    Returns:
        Dictionary mapping each property name to a tensor of shape
        ``(z_max,)`` with the estimated contribution per atomic number.
        Elements that do not occur in the data keep the value zero.
    """
    property_names = list(divide_by_atoms.keys())
    n_data = len(loader.dataset)
    out_dtype = torch.get_default_dtype()

    # Accumulate in double precision so that the regression stays accurate
    # even when the property values are large in magnitude.
    targets = torch.zeros((n_data, len(property_names)), dtype=torch.float64)
    type_counts = torch.zeros((n_data, z_max), dtype=torch.float64)
    n_seen = 0

    # loop over all batches
    for batch in tqdm(loader):
        idx_m = batch[properties.idx_m]
        atomic_numbers = batch[properties.Z]

        # one row of atom-type counts and one row of targets per structure
        for i in torch.unique(idx_m):
            atom_types, atom_counts = torch.unique(
                atomic_numbers[idx_m == i], return_counts=True
            )
            type_counts[n_seen, atom_types] = atom_counts.to(type_counts.dtype)
            for col, pname in enumerate(property_names):
                targets[n_seen, col] = batch[pname][i]
            n_seen += 1

    # only keep the rows that were actually filled
    type_counts = type_counts[:n_seen]
    targets = targets[:n_seen]

    atomrefs = {
        pname: torch.zeros(z_max, dtype=out_dtype) for pname in property_names
    }
    existing_atom_types = torch.where(type_counts.sum(dim=0) != 0)[0]
    if not property_names or existing_atom_types.numel() == 0:
        return atomrefs

    # Least-squares fit restricted to the elements present in the data. The
    # design matrix is kept two-dimensional (no squeeze), and the
    # pseudo-inverse also copes with rank-deficient compositions for which
    # ``inv(X.T @ X)`` would fail.
    X = type_counts[:, existing_atom_types]
    weights = torch.linalg.pinv(X) @ targets

    # scatter the fitted weights back to their atomic numbers
    for col, pname in enumerate(property_names):
        atomrefs[pname][existing_atom_types] = weights[:, col].to(out_dtype)

    return atomrefs
10 changes: 7 additions & 3 deletions src/schnetpack/transform/atomistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def datamodule(self, _datamodule):
"""
if self.remove_atomrefs and not self._atomrefs_initialized:
atrefs = _datamodule.train_dataset.atomrefs
if atrefs[self._property] is None:
atrefs = _datamodule.get_atomrefs(self._property, self.is_extensive)
self.atomref = atrefs[self._property].detach()

if self.remove_mean and not self._mean_initialized:
Expand Down Expand Up @@ -264,13 +266,15 @@ def __init__(
self.register_buffer("atomref", atomrefs)
self.register_buffer("mean", property_mean)

def datamodule(self, value):
def datamodule(self, _datamodule):
if self.add_atomrefs and not self._atomrefs_initialized:
atrefs = value.train_dataset.atomrefs
atrefs = _datamodule.train_dataset.atomrefs
if atrefs[self._property] is None:
atrefs = _datamodule.get_atomrefs(self._property, self.is_extensive)
self.atomref = atrefs[self._property].detach()

if self.add_mean and not self._mean_initialized:
stats = value.get_stats(
stats = _datamodule.get_stats(
self._property, self.is_extensive, self.add_atomrefs
)
self.mean = stats[0].detach()
Expand Down

0 comments on commit 6a2467e

Please sign in to comment.