Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specifiable options for periodic neighbors calculations #318

Merged
merged 27 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5fea3a8
refactor: adding max_neighbors kwarg to pbc transform
laserkelvin Nov 13, 2024
df21736
refactor: allow pymatgen neighbors calculation to truncate max neighbors
laserkelvin Nov 13, 2024
4fe3b20
refactor & docs: adding docstring to ase function counterpart, and ad…
laserkelvin Nov 13, 2024
228c2f8
refactor: adding max neighbor condition check in ase neighbor calcula…
laserkelvin Nov 13, 2024
4ca6a87
fix: correct variable name in ase cutoff check
laserkelvin Nov 13, 2024
0999984
refactor: passing max neighbors into respective functions within tran…
laserkelvin Nov 13, 2024
566f3e0
refactor: adding self loop option in transform
laserkelvin Nov 13, 2024
e485eb0
refactor: added masking operation for optional self loop removal
laserkelvin Nov 13, 2024
7437fc6
fix: correcting order of data update for self loop check
laserkelvin Nov 13, 2024
555eb8c
fix: correcting loop break condition
laserkelvin Nov 14, 2024
beb4418
fix: added missing reshape in ase neighbor calculation
laserkelvin Nov 14, 2024
84aa0c4
test: added general unit test for edge calculations
laserkelvin Nov 14, 2024
6828dd3
feat: added fractional cartesian conversion function
laserkelvin Nov 14, 2024
ae87d32
refactor: removing assert statement and just using matmul
laserkelvin Nov 14, 2024
272c156
feat: added function for building images exhaustively
laserkelvin Nov 15, 2024
4a66623
test: temporarily removing failing configuration
laserkelvin Nov 15, 2024
bc6d73f
refactor: allowing options to manually specify pymatgen periodic stru…
laserkelvin Nov 17, 2024
2f48299
refactor: making to_unit_cell arg default to None
laserkelvin Nov 17, 2024
b5252cf
refactor: adding warning to is_cartesian if backend is ase
laserkelvin Nov 17, 2024
64e225d
fix: correcting doubled up all_images check
laserkelvin Nov 18, 2024
1d219c3
fix: missing max_neighbors in shortcut periodic calculation
laserkelvin Nov 18, 2024
34f5cb8
refactor: lowering fractional coordinate check to warning
laserkelvin Nov 18, 2024
61bbf1d
fix: reshaping cell tensor for ase calculation if needed
laserkelvin Nov 18, 2024
cd05acc
refactor: adding fractional coords into warning message
laserkelvin Nov 18, 2024
973a1b9
docs: added link to wiki equation for rotation matrix
laserkelvin Nov 18, 2024
762d9be
docs: updated pbc sphinx docs
laserkelvin Nov 18, 2024
f1b14ac
docs: revamped docstring for periodic boundary condition transform
laserkelvin Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions matsciml/datasets/transforms/pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch
from pymatgen.core import Lattice, Structure
from loguru import logger

from matsciml.common.types import DataDict
from matsciml.datasets.transforms.base import AbstractDataTransform
Expand Down Expand Up @@ -38,11 +39,23 @@ def __init__(
cutoff_radius: float,
adaptive_cutoff: bool = False,
backend: Literal["pymatgen", "ase"] = "pymatgen",
max_neighbors: int = 1000,
allow_self_loops: bool = False,
convert_to_unit_cell: bool = False,
is_cartesian: bool | None = None,
) -> None:
super().__init__()
self.cutoff_radius = cutoff_radius
self.adaptive_cutoff = adaptive_cutoff
self.backend = backend
self.max_neighbors = max_neighbors
self.allow_self_loops = allow_self_loops
if is_cartesian is not None and backend == "ase":
logger.warning(
"`is_cartesian` passed but using `ase` backend; option will not affect anything."
)
self.is_cartesian = is_cartesian
self.convert_to_unit_cell = convert_to_unit_cell

def __call__(self, data: DataDict) -> DataDict:
"""
Expand Down Expand Up @@ -84,7 +97,10 @@ def __call__(self, data: DataDict) -> DataDict:
structure = data["structure"]
if isinstance(structure, Structure):
graph_props = calculate_periodic_shifts(
structure, self.cutoff_radius, self.adaptive_cutoff
structure,
self.cutoff_radius,
self.adaptive_cutoff,
max_neighbors=self.max_neighbors,
)
data.update(graph_props)
return data
Expand Down Expand Up @@ -123,16 +139,25 @@ def __call__(self, data: DataDict) -> DataDict:
data["atomic_numbers"],
data["pos"],
lattice=lattice,
convert_to_unit_cell=self.convert_to_unit_cell,
is_cartesian=self.is_cartesian,
)
if self.backend == "pymatgen":
graph_props = calculate_periodic_shifts(
structure, self.cutoff_radius, self.adaptive_cutoff
structure, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
)
elif self.backend == "ase":
graph_props = calculate_ase_periodic_shifts(
data, self.cutoff_radius, self.adaptive_cutoff
data, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors
)
else:
raise RuntimeError(f"Requested backend f{self.backend} not available.")
data.update(graph_props)
if not self.allow_self_loops:
mask = data["src_nodes"] == data["dst_nodes"]
# only mask out self-loops within the same image
mask &= data["unit_offsets"].sum(dim=-1) == 0
# apply mask to each of the tensors that depend on edges
for key in ["src_nodes", "dst_nodes", "images", "unit_offsets", "offsets"]:
data[key] = data[key][mask]
return data
87 changes: 87 additions & 0 deletions matsciml/datasets/transforms/tests/test_pbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations

from collections import Counter

import torch
import pytest
import numpy as np
from pymatgen.core import Structure, Lattice

from matsciml.datasets.transforms import PeriodicPropertiesTransform

"""
This module uses reference Materials project structures and tests
the edge calculation routines to ensure they at least work with
various parameters.

The key thing here is at least using feasible structures to perform
this check, rather than using randomly generated coordinates and
lattices, even if composing them isn't meaningful.
"""

hexa = Lattice.from_parameters(
4.81, 4.809999999999999, 13.12, 90.0, 90.0, 120.0, vesta=True
)
cubic = Lattice.from_parameters(6.79, 6.79, 12.63, 90.0, 90.0, 90.0, vesta=True)

# mp-1143
alumina = Structure(
hexa,
species=["Al", "O"],
coords=[[1 / 3, 2 / 3, 0.814571], [0.360521, 1 / 3, 0.583333]],
coords_are_cartesian=False,
)
# mp-1267
nac = Structure(
cubic,
species=["Na", "C"],
coords=[[0.688819, 3 / 4, 3 / 8], [0.065833, 0.565833, 0.0]],
coords_are_cartesian=False,
)


@pytest.mark.parametrize(
"coords",
[
alumina.cart_coords,
nac.cart_coords,
],
)
@pytest.mark.parametrize(
"cell",
[
hexa.matrix,
cubic.matrix,
],
)
@pytest.mark.parametrize("self_loops", [True, False])
@pytest.mark.parametrize("backend", ["pymatgen", "ase"])
@pytest.mark.parametrize(
"cutoff_radius", [6.0, 9.0, 15.0]
) # TODO figure out why pmg fails on 3
def test_periodic_generation(
coords: np.ndarray,
cell: np.ndarray,
self_loops: bool,
backend: str,
cutoff_radius: float,
):
coords = torch.FloatTensor(coords)
cell = torch.FloatTensor(cell)
transform = PeriodicPropertiesTransform(
cutoff_radius=cutoff_radius,
adaptive_cutoff=False,
backend=backend,
max_neighbors=10,
allow_self_loops=self_loops,
)
num_atoms = coords.size(0)
atomic_numbers = torch.ones(num_atoms)
packed_data = {"pos": coords, "cell": cell, "atomic_numbers": atomic_numbers}
output = transform(packed_data)
# check to make sure no source node has more than 10 neighbors
src_nodes = output["src_nodes"].tolist()
counts = Counter(src_nodes)
for index, count in counts.items():
if not self_loops:
assert count < 10, print(f"Node {index} has too many counts. {src_nodes}")
Loading
Loading