Skip to content

Commit

Permalink
Merge pull request #335 from laserkelvin/final-final-edge-fix
Browse files: browse the repository at this point in the history
(Pen)ultimate edge logic fix
  • Loading branch information
laserkelvin authored Dec 20, 2024
2 parents 8a0765d + d4e9e36 commit a28a5f5
Showing 1 changed file with 62 additions and 55 deletions.
117 changes: 62 additions & 55 deletions matsciml/datasets/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -607,12 +607,11 @@ def element_types():


class Edge:
def __init__(
self, src: int, dst: int, image: np.ndarray, is_undirected: bool = False
):
def __init__(self, src: int, dst: int, image: np.ndarray):
"""
Initializes the Edge object with the source, destination,
and image, ensuring directionality and handling self-loops.
Parameters
----------
src : int
Expand All @@ -622,17 +621,15 @@ def __init__(
image : np.ndarray
1D vector of three elements as a ``np.ndarray``.
"""
if is_undirected:
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src == dst:
# For self-loops, enforce a canonical form of the image
image = self._canonicalize_image(image)
self.src = src
self.dst = dst
self.image = image
self.is_undirected = is_undirected

@staticmethod
def _canonicalize_image(image: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -686,10 +683,7 @@ def __eq__(self, other) -> bool:
"""
if not isinstance(other, Edge):
return False
if self.is_undirected:
node_eq = self.directed_index == other.directed_index
else:
node_eq = (self.src == other.src) and (self.dst == other.dst)
node_eq = self.directed_index == other.directed_index
return node_eq and np.array_equal(self.image, other.image)

def __str__(self) -> str:
Expand Down Expand Up @@ -854,32 +848,41 @@ def _all_sites_have_neighbors(neighbors):
raise ValueError(
f"No neighbors detected for structure with cutoff {cutoff}; {structure}"
)
keep = set()
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
is_undirected,
# if we assume undirected edges, apply a filter
if is_undirected:
keep = set()
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
)
)
)
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# stop adding edges if either src/dst have accumulated enough neighbors
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# stop adding edges if either src/dst have accumulated enough neighbors
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1
# alternatively, just add the edges as is from pymatgen
else:
all_src, all_dst, all_images = [], [], []
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
all_src.append(src_idx)
all_dst.append(site.index)
all_images.append(site.image)
if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]):
raise ValueError(
f"No images or edges to work off for cutoff {cutoff}."
Expand Down Expand Up @@ -963,24 +966,28 @@ def calculate_ase_periodic_shifts(
)
# not really needed but good sanity check
assert np.all(distances <= cutoff_radius)
keep = set()
# only keeps undirected edges that are unique
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src=src, dst=dst, image=image, is_undirected=is_undirected))

all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# obey max_neighbors by not adding any more edges
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1

# in the undirected case, we will filter out
# half of the edges as src/dst == dst/src for a given image
if is_undirected:
keep = set()
# only keeps undirected edges that are unique
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src=src, dst=dst, image=image))

all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# obey max_neighbors by not adding any more edges
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1

frac_coords = torch.from_numpy(atoms.get_scaled_positions()).float()
coords = torch.from_numpy(atoms.positions).float()
Expand Down

0 comments on commit a28a5f5

Please sign in to comment.