Skip to content

Commit

Permalink
Merge pull request #333 from laserkelvin/final-edge-creation
Browse files Browse the repository at this point in the history
Final edge creation
  • Loading branch information
laserkelvin authored Dec 18, 2024
2 parents ebe17f0 + 058273c commit 8a0765d
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 70 deletions.
40 changes: 40 additions & 0 deletions matsciml/datasets/tests/test_edge_logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import pytest
import numpy as np

from matsciml.datasets.utils import Edge


def test_non_self_interaction():
"""
These two nodes edges should not be equivalent, since they
are not self-interactions and the images are different
"""
a = Edge(src=0, dst=10, image=np.array([-1, 0, 0]))
b = Edge(src=0, dst=10, image=np.array([1, 0, 0]))
assert a != b


def test_self_interaction_image():
"""
These two edges are mirror images of one another since
the src/dst are the same node.
"""
a = Edge(src=0, dst=0, image=np.array([-1, 0, 0]))
b = Edge(src=0, dst=0, image=np.array([1, 0, 0]))
assert a == b


@pytest.mark.parametrize("is_undirected", [True, False])
def test_directed_edges(is_undirected):
"""
These two are the same edge in the undirected case,
but are different if treating directed graphs
"""
a = Edge(src=5, dst=10, image=np.array([0, 0, 0]), is_undirected=is_undirected)
b = Edge(src=10, dst=5, image=np.array([0, 0, 0]), is_undirected=is_undirected)
if is_undirected:
assert a == b
else:
assert a != b
15 changes: 13 additions & 2 deletions matsciml/datasets/transforms/pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
allow_self_loops: bool = False,
convert_to_unit_cell: bool = False,
is_cartesian: bool | None = None,
is_undirected: bool = False,
) -> None:
"""
Rewires an already present graph to include periodic boundary conditions.
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
)
self.is_cartesian = is_cartesian
self.convert_to_unit_cell = convert_to_unit_cell
self.is_undirected = is_undirected

def __call__(self, data: DataDict) -> DataDict:
"""
Expand Down Expand Up @@ -135,6 +137,7 @@ def __call__(self, data: DataDict) -> DataDict:
self.cutoff_radius,
self.adaptive_cutoff,
max_neighbors=self.max_neighbors,
is_undirected=self.is_undirected,
)
data.update(graph_props)
return data
Expand Down Expand Up @@ -179,11 +182,19 @@ def __call__(self, data: DataDict) -> DataDict:
is_cartesian=self.is_cartesian,
)
graph_props = calculate_periodic_shifts(
structure, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
structure,
self.cutoff_radius,
self.adaptive_cutoff,
self.max_neighbors,
self.is_undirected,
)
elif self.backend == "ase":
graph_props = calculate_ase_periodic_shifts(
data, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
data,
self.cutoff_radius,
self.adaptive_cutoff,
self.max_neighbors,
self.is_undirected,
)
else:
raise RuntimeError(f"Requested backend f{self.backend} not available.")
Expand Down
167 changes: 99 additions & 68 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import pickle
from dataclasses import dataclass
from collections.abc import Generator
from functools import lru_cache, partial
from os import makedirs
Expand Down Expand Up @@ -607,91 +606,114 @@ def element_types():
return list(atomic_number_map().keys())


@dataclass
class Edge:
"""
Implements a data structure for edge redundancy comparison
with a syntactic sugar.
Implements a ``sorted_index`` property to returns a pair
of indices for the edge, irrespective of direction. This,
in addition to the ``image`` of the edge is used in the
``__eq__`` comparison.
Finally, ``__hash__`` is based off the string representation
of this object, making it hashable and usable in sets.
Attributes
----------
src : int
Index of the source node of the edge.
dst : int
Index of the destination node of the edge.
image : np.ndarray
1D vector of three elements as a ``np.ndarray``.
"""

src: int
dst: int
image: np.ndarray

@property
def sorted_index(self) -> tuple[int, int]:
return (min(self.src, self.dst), max(self.src, self.dst))
def __init__(
self, src: int, dst: int, image: np.ndarray, is_undirected: bool = False
):
"""
Initializes the Edge object with the source, destination,
and image, ensuring directionality and handling self-loops.
Parameters
----------
src : int
Index of the source node of the edge.
dst : int
Index of the destination node of the edge.
image : np.ndarray
1D vector of three elements as a ``np.ndarray``.
"""
if is_undirected:
if src > dst:
# Enforce directed order
src, dst, image = dst, src, -image
if src == dst:
# For self-loops, enforce a canonical form of the image
image = self._canonicalize_image(image)
self.src = src
self.dst = dst
self.image = image
self.is_undirected = is_undirected

@staticmethod
def _canonicalize_image(image: np.ndarray) -> np.ndarray:
"""
Canonicalizes the image vector for self-loops to ensure consistent
representation, eliminating mirrored duplicates.
Parameters
----------
image : np.ndarray
The image vector to canonicalize.
Returns
-------
np.ndarray
The canonicalized image vector.
"""
# Find the first non-zero component
for i in range(len(image)):
if image[i] != 0:
# Flip the image if the first non-zero component is negative
if image[i] < 0:
return -image
break
return image

@property
def unsigned_image(self) -> np.ndarray:
def directed_index(self) -> tuple[int, int]:
"""
Returns the absolute image indices.
This is used when considering parity-inversion,
i.e. two edges are equivalent if they are in
in mirroring cell images.
Returns the directed pair of indices for the edge.
Returns
-------
np.ndarray
Indices without parity
tuple[int, int]
The pair (src, dst) with src < dst.
"""
return np.abs(self.image)
return self.src, self.dst

def __eq__(self, other: Edge) -> bool:
index_eq = self.sorted_index == other.sorted_index
image_eq = np.all(self.unsigned_image == other.unsigned_image)
return all([index_eq, image_eq])
def __eq__(self, other) -> bool:
"""
Compares two edges for equality, considering both the
directed (src, dst) pair and the image.
def __str__(self) -> str:
"""Represents the edge without phase or parity information. Mainly for hashing."""
return f"Sorted src/dst: {self.sorted_index}, |image|: {self.unsigned_image}"
Parameters
----------
other : Edge
The other edge to compare against.
def __hash__(self) -> int:
Returns
-------
bool
True if the edges are equivalent, otherwise False.
"""
This hash method is primarily intended for use in
``set`` comparisons to check for uniqueness.
if not isinstance(other, Edge):
return False
if self.is_undirected:
node_eq = self.directed_index == other.directed_index
else:
node_eq = (self.src == other.src) and (self.dst == other.dst)
return node_eq and np.array_equal(self.image, other.image)

The general idea for edge-uniqueness is assuming
the undirected case where ``src`` and ``dst``
is exchangable, and when not considering phase
for image indices.
def __str__(self) -> str:
"""
Represents the edge with the directed pair (src, dst) and image.
As an example, the following two edges are
equivalent as they have interchangable ``src``
and ``dst`` indices, **and** the displacement
owing to image offsets is the same - just in
opposite directions,
Returns
-------
str
A string representation of the edge.
"""
return f"Directed src/dst: {self.directed_index}, image: {self.image}"

```
Edge(src=1, dst=7, image=[-1, 0, 0])
Edge(src=7, dst=1, image[1, 0, 0])
```
def __hash__(self) -> int:
"""
Hashes the edge for use in sets and dicts.
The hash is based on the directed (src, dst) pair and the image.
Returns
-------
int
Permutation invariant hash of this edge.
The hash of this edge.
"""

return hash(str(self))
return hash((self.directed_index, tuple(self.image)))


def make_pymatgen_periodic_structure(
Expand Down Expand Up @@ -765,6 +787,7 @@ def calculate_periodic_shifts(
cutoff: float,
adaptive_cutoff: bool = False,
max_neighbors: int = 1000,
is_undirected: bool = False,
) -> dict[str, torch.Tensor]:
"""
Compute properties with respect to periodic boundary conditions.
Expand Down Expand Up @@ -835,7 +858,14 @@ def _all_sites_have_neighbors(neighbors):
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(Edge(src_idx, site.index, np.array(site.image)))
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
is_undirected,
)
)
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
Expand Down Expand Up @@ -883,6 +913,7 @@ def calculate_ase_periodic_shifts(
cutoff_radius: float,
adaptive_cutoff: bool,
max_neighbors: int = 1000,
is_undirected: bool = False,
) -> dict[str, torch.Tensor]:
"""
Calculate edges for the system using ``ase`` routines.
Expand Down Expand Up @@ -935,7 +966,7 @@ def calculate_ase_periodic_shifts(
keep = set()
# only keeps undirected edges that are unique
for src, dst, image in zip(all_src, all_dst, all_images):
keep.add(Edge(src, dst, image))
keep.add(Edge(src=src, dst=dst, image=image, is_undirected=is_undirected))

all_src, all_dst, all_images = [], [], []
num_atoms = len(atoms)
Expand Down

0 comments on commit 8a0765d

Please sign in to comment.