diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index 74bd6a8f..50413351 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -607,12 +607,11 @@ def element_types(): class Edge: - def __init__( - self, src: int, dst: int, image: np.ndarray, is_undirected: bool = False - ): + def __init__(self, src: int, dst: int, image: np.ndarray): """ Initializes the Edge object with the source, destination, and image, ensuring directionality and handling self-loops. + Parameters ---------- src : int @@ -622,17 +621,15 @@ def __init__( image : np.ndarray 1D vector of three elements as a ``np.ndarray``. """ - if is_undirected: - if src > dst: - # Enforce directed order - src, dst, image = dst, src, -image + if src > dst: + # Enforce directed order + src, dst, image = dst, src, -image if src == dst: # For self-loops, enforce a canonical form of the image image = self._canonicalize_image(image) self.src = src self.dst = dst self.image = image - self.is_undirected = is_undirected @staticmethod def _canonicalize_image(image: np.ndarray) -> np.ndarray: @@ -686,10 +683,7 @@ def __eq__(self, other) -> bool: """ if not isinstance(other, Edge): return False - if self.is_undirected: - node_eq = self.directed_index == other.directed_index - else: - node_eq = (self.src == other.src) and (self.dst == other.dst) + node_eq = self.directed_index == other.directed_index return node_eq and np.array_equal(self.image, other.image) def __str__(self) -> str: @@ -865,7 +859,6 @@ def _all_sites_have_neighbors(neighbors): src_idx, site.index, np.array(site.image), - is_undirected, ) ) # now only keep the edges after the first loop @@ -980,7 +973,7 @@ def calculate_ase_periodic_shifts( keep = set() # only keeps undirected edges that are unique for src, dst, image in zip(all_src, all_dst, all_images): - keep.add(Edge(src=src, dst=dst, image=image, is_undirected=is_undirected)) + keep.add(Edge(src=src, dst=dst, image=image)) all_src, all_dst, all_images = [], [], [] num_atoms = len(atoms)