Skip to content

Commit

Permalink
fix & refactor: making ase pass tests by returning self-loops and obe…
Browse files Browse the repository at this point in the history
…ying max_neighbors
  • Loading branch information
laserkelvin committed Nov 25, 2024
1 parent ad6e146 commit 1101ae7
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def calculate_ase_periodic_shifts(
pbc=(True, True, True),
)
all_src, all_dst, distances, all_images = neighbor_list(
"ijdS", atoms, cutoff=cutoff_radius, self_interaction=False
"ijdS", atoms, cutoff=cutoff_radius, self_interaction=True
)
# not really needed but good sanity check
assert np.all(distances <= cutoff_radius)
Expand All @@ -893,9 +893,19 @@ def calculate_ase_periodic_shifts(
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src, dst, image))

all_src = [edge.src for edge in keep]
all_dst = [edge.dst for edge in keep]
all_images = [edge.image for edge in keep]
all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# obey max_neighbors by not adding any more edges
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1

frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float()
coords = torch.from_numpy(atoms.positions).float()
Expand Down

0 comments on commit 1101ae7

Please sign in to comment.