Skip to content

Commit e5564aa

Browse files
run black
1 parent c0fce7a commit e5564aa

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/schnetpack/data/stats.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def calculate_stats(
5656
sample_values.append(val)
5757
sample_values = torch.cat(sample_values, dim=0)
5858

59-
6059
batch_size = sample_values.shape[1]
6160
new_count = count + batch_size
6261

@@ -127,10 +126,16 @@ def estimate_atomrefs(dataloader, is_extensive, z_max=100):
127126
if is_extensive[pname]:
128127
w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname]
129128
else:
130-
w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ (all_properties[pname] / X.sum(axis=1))
129+
w[pname] = (
130+
torch.linalg.inv(X.T @ X)
131+
@ X.T
132+
@ (all_properties[pname] / X.sum(axis=1))
133+
)
131134

132135
# compute energy estimates
133-
elementwise_contributions = {pname: torch.zeros((z_max)) for pname in property_names}
136+
elementwise_contributions = {
137+
pname: torch.zeros((z_max)) for pname in property_names
138+
}
134139
for pname in property_names:
135140
for atom_type, weight in zip(existing_atom_types, w[pname]):
136141
elementwise_contributions[pname][atom_type] = weight

0 commit comments

Comments
 (0)