Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
stefaanhessmann committed Aug 28, 2024
1 parent c0fce7a commit e5564aa
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/schnetpack/data/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def calculate_stats(
sample_values.append(val)
sample_values = torch.cat(sample_values, dim=0)


batch_size = sample_values.shape[1]
new_count = count + batch_size

Expand Down Expand Up @@ -127,10 +126,16 @@ def estimate_atomrefs(dataloader, is_extensive, z_max=100):
if is_extensive[pname]:
w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname]
else:
w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ (all_properties[pname] / X.sum(axis=1))
w[pname] = (
torch.linalg.inv(X.T @ X)
@ X.T
@ (all_properties[pname] / X.sum(axis=1))
)

# compute energy estimates
elementwise_contributions = {pname: torch.zeros((z_max)) for pname in property_names}
elementwise_contributions = {
pname: torch.zeros((z_max)) for pname in property_names
}
for pname in property_names:
for atom_type, weight in zip(existing_atom_types, w[pname]):
elementwise_contributions[pname][atom_type] = weight
Expand Down

0 comments on commit e5564aa

Please sign in to comment.