diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index 0b2ca9d1..909c988e 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -884,7 +884,7 @@ def calculate_ase_periodic_shifts( pbc=(True, True, True), ) all_src, all_dst, distances, all_images = neighbor_list( - "ijdS", atoms, cutoff=cutoff_radius, self_interaction=False + "ijdS", atoms, cutoff=cutoff_radius, self_interaction=True ) # not really needed but good sanity check assert np.all(distances <= cutoff_radius) @@ -893,9 +893,19 @@ def calculate_ase_periodic_shifts( for src, dst, image in zip(all_src, all_dst, all_images): keep.add(Edge(src, dst, image)) - all_src = [edge.src for edge in keep] - all_dst = [edge.dst for edge in keep] - all_images = [edge.image for edge in keep] + all_src, all_dst, all_images = [], [], [] + num_atoms = len(atoms) + counter = {index: 0 for index in range(num_atoms)} + for edge in keep: + # obey max_neighbors by not adding any more edges + if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors: + pass + else: + all_src.append(edge.src) + all_dst.append(edge.dst) + all_images.append(edge.image) + counter[edge.src] += 1 + counter[edge.dst] += 1 frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float() coords = torch.from_numpy(atoms.positions).float()