Skip to content

Commit

Permalink
refactor: removing is_undirected from edge class definition
Browse files Browse the repository at this point in the history
Not needed now since the branching is in ase/pymatgen routines now
  • Loading branch information
laserkelvin committed Dec 20, 2024
1 parent a40db35 commit 3f208a1
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,11 @@ def element_types():


class Edge:
def __init__(
self, src: int, dst: int, image: np.ndarray, is_undirected: bool = False
):
def __init__(self, src: int, dst: int, image: np.ndarray):
"""
Initializes the Edge object with the source, destination,
and image, ensuring directionality and handling self-loops.
Parameters
----------
src : int
Expand All @@ -622,17 +621,15 @@ def __init__(
image : np.ndarray
1D vector of three elements as a ``np.ndarray``.
"""
if is_undirected:
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src == dst:
# For self-loops, enforce a canonical form of the image
image = self._canonicalize_image(image)
self.src = src
self.dst = dst
self.image = image
self.is_undirected = is_undirected

@staticmethod
def _canonicalize_image(image: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -686,10 +683,7 @@ def __eq__(self, other) -> bool:
"""
if not isinstance(other, Edge):
return False
if self.is_undirected:
node_eq = self.directed_index == other.directed_index
else:
node_eq = (self.src == other.src) and (self.dst == other.dst)
node_eq = self.directed_index == other.directed_index
return node_eq and np.array_equal(self.image, other.image)

def __str__(self) -> str:
Expand Down Expand Up @@ -865,7 +859,6 @@ def _all_sites_have_neighbors(neighbors):
src_idx,
site.index,
np.array(site.image),
is_undirected,
)
)
# now only keep the edges after the first loop
Expand Down Expand Up @@ -980,7 +973,7 @@ def calculate_ase_periodic_shifts(
keep = set()
# only keeps undirected edges that are unique
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src=src, dst=dst, image=image, is_undirected=is_undirected))
keep.add(Edge(src=src, dst=dst, image=image))

all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
Expand Down

0 comments on commit 3f208a1

Please sign in to comment.