diff --git a/docs/source/best-practices.rst b/docs/source/best-practices.rst index 2fbf9a1b..bb8065e1 100644 --- a/docs/source/best-practices.rst +++ b/docs/source/best-practices.rst @@ -20,17 +20,22 @@ The way this is implemented in ``matsciml`` is to include the transform, .. autofunction:: matsciml.datasets.transforms.PeriodicPropertiesTransform -This implementation is heavily based off -the tutorial outlined in the `e3nn documentation`_ where we use ``pymatgen`` -to generate images, and for every atom in the graph, -compute nearest neighbors with some specified radius cutoff. One additional -detail we include in this approach is the ``adaptive_cutoff`` flag: if set to ``True``, will ensure -that all nodes are connected by gradually increasing the radius cutoff up -to a hard coded limit of 100 angstroms. This is intended to facilitate the -a small nominal cutoff, even if some data samples contain (intentionally) -significantly more distant atoms than the average sample. By doing so, we -improve computational efficiency by not needing to consider many more edges -than required. +This implementation was originally based off +the tutorial outlined in the `e3nn documentation`_. We initially provided +an implementation that uses `pymatgen` for the neighborhood calculation, +but have since extended it to use `ase` as well. We find that `ase` is +slightly less ambiguous with coordinate representations, but results from +the two can be mapped to yield the same behavior. In either case, the coordinates +and lattice parameters are passed into their respective backend representations +(i.e. ``ase.Atoms`` and ``pymatgen.Structure``), and subsequently used to +perform the neighborhood calculation to obtain source/destination node indices +for the edges, as well as their associated periodic image indices. + +Below are descriptions of the two algorithms, and links to their source code. + +.. autofunction:: matsciml.datasets.utils.calculate_periodic_shifts + +.. autofunction:: matsciml.datasets.utils.calculate_ase_periodic_shifts Point clouds to graphs ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/matsciml/datasets/transforms/pbc.py b/matsciml/datasets/transforms/pbc.py index 3f6e24dd..eedaacbf 100644 --- a/matsciml/datasets/transforms/pbc.py +++ b/matsciml/datasets/transforms/pbc.py @@ -5,6 +5,7 @@ import numpy as np import torch from pymatgen.core import Lattice, Structure +from loguru import logger from matsciml.common.types import DataDict from matsciml.datasets.transforms.base import AbstractDataTransform @@ -18,31 +19,76 @@ class PeriodicPropertiesTransform(AbstractDataTransform): - """ - Rewires an already present graph to include periodic boundary conditions. - - Since graphs are normally bounded within a unit cell, they may not capture - the necessary dependencies for atoms connected to neighboring cells. This - transform will compute the unit cell, tile it, and then rewire the graph - edges such that it can capture connectivity given a radial cutoff given - in Angstroms. - - Cut off radius is specified in Angstroms. An additional flag, ``adaptive_cutoff``, - allows the cut off value to grow up to 100 angstroms in order to find neighbors. - This allows larger (typically unstable) structures to be modeled without applying - a large cut off for the entire dataset. - """ - def __init__( self, cutoff_radius: float, adaptive_cutoff: bool = False, backend: Literal["pymatgen", "ase"] = "pymatgen", + max_neighbors: int = 1000, + allow_self_loops: bool = False, + convert_to_unit_cell: bool = False, + is_cartesian: bool | None = None, ) -> None: + """ + Rewires an already present graph to include periodic boundary conditions. + + Since graphs are normally bounded within a unit cell, they may not capture + the necessary dependencies for atoms connected to neighboring cells. This + transform will compute the unit cell, tile it, and then rewire the graph + edges such that it can capture connectivity given a radial cutoff given + in Angstroms. + + Cut off radius is specified in Angstroms. An additional flag, ``adaptive_cutoff``, + allows the cut off value to grow up to 100 angstroms in order to find neighbors. + This allows larger (typically unstable) structures to be modeled without applying + a large cut off for the entire dataset. + + Parameters + ---------- + cutoff_radius : float + Cutoff radius to use to truncate the neighbor list calculation. + adaptive_cutoff : bool, default False + If set to ``True``, will allow ``cutoff_radius`` to grow up to + 30 angstroms if there are any disconnected subgraphs present. + This is to allow distant nodes to be captured in some structures + only as needed, keeping the computational requirements low for + other samples within a dataset. + backend : Literal['pymatgen', 'ase'], default 'pymatgen' + Which algorithm to use for the neighbor list calculation. Nominally + settings can be mapped to have the two produce equivalent results. + 'pymatgen' is kept as the default, but at some point 'ase' will + become the default option. See the hosted documentation 'Best practices' + page for details. + max_neighbors : int, default 1000 + Forcibly truncate the number of edges at any given node. Internally, + a counter is used to track the number of destination nodes when + looping over a node's neighbor list; when the counter exceeds this + value we immediately stop counting neighbors for the current node. + allow_self_loops : bool, default False + If ``True``, the edges will include self-interactions within the + original unit cell. If set to ``False``, these self-loops are + purged before returning edges. + convert_to_unit_cell : bool, default False + This argument is specific to ``pymatgen``, which is passed to the + ``to_unit_cell`` argument during the ``Structure`` construction step. + is_cartesian : bool | None, default None + If set to ``None``, we will try and determine if the structure has + fractional coordinates as input or not. If a boolean is provided, + this is passed into the ``pymatgen.Structure`` construction step. + This is specific to ``pymatgen``, and is not used by ``ase``. + """ super().__init__() self.cutoff_radius = cutoff_radius self.adaptive_cutoff = adaptive_cutoff self.backend = backend + self.max_neighbors = max_neighbors + self.allow_self_loops = allow_self_loops + if is_cartesian is not None and backend == "ase": + logger.warning( + "`is_cartesian` passed but using `ase` backend; option will not affect anything." + ) + self.is_cartesian = is_cartesian + self.convert_to_unit_cell = convert_to_unit_cell def __call__(self, data: DataDict) -> DataDict: """ @@ -84,7 +130,10 @@ def __call__(self, data: DataDict) -> DataDict: structure = data["structure"] if isinstance(structure, Structure): graph_props = calculate_periodic_shifts( - structure, self.cutoff_radius, self.adaptive_cutoff + structure, + self.cutoff_radius, + self.adaptive_cutoff, + max_neighbors=self.max_neighbors, ) data.update(graph_props) return data @@ -123,16 +172,25 @@ def __call__(self, data: DataDict) -> DataDict: data["atomic_numbers"], data["pos"], lattice=lattice, + convert_to_unit_cell=self.convert_to_unit_cell, + is_cartesian=self.is_cartesian, ) if self.backend == "pymatgen": graph_props = calculate_periodic_shifts( - structure, self.cutoff_radius, self.adaptive_cutoff + structure, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors ) elif self.backend == "ase": graph_props = calculate_ase_periodic_shifts( - data, self.cutoff_radius, self.adaptive_cutoff + data, self.cutoff_radius, self.adaptive_cutoff, self.max_neighbors ) else: raise RuntimeError(f"Requested backend f{self.backend} not available.") data.update(graph_props) + if not self.allow_self_loops: + mask = data["src_nodes"] == data["dst_nodes"] + # only mask out self-loops within the same image + mask &= data["unit_offsets"].sum(dim=-1) == 0 + # apply mask to each of the tensors that depend on edges + for key in ["src_nodes", "dst_nodes", "images", "unit_offsets", "offsets"]: + data[key] = data[key][mask] return data diff --git a/matsciml/datasets/transforms/tests/test_pbc.py b/matsciml/datasets/transforms/tests/test_pbc.py new file mode 100644 index 00000000..2878f614 --- /dev/null +++ b/matsciml/datasets/transforms/tests/test_pbc.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from collections import Counter + +import torch +import pytest +import numpy as np +from pymatgen.core import Structure, Lattice + +from matsciml.datasets.transforms import PeriodicPropertiesTransform + +""" +This module uses reference Materials project structures and tests +the edge calculation routines to ensure they at least work with +various parameters. + +The key thing here is at least using feasible structures to perform +this check, rather than using randomly generated coordinates and +lattices, even if composing them isn't meaningful. +""" + +hexa = Lattice.from_parameters( + 4.81, 4.809999999999999, 13.12, 90.0, 90.0, 120.0, vesta=True +) +cubic = Lattice.from_parameters(6.79, 6.79, 12.63, 90.0, 90.0, 90.0, vesta=True) + +# mp-1143 +alumina = Structure( + hexa, + species=["Al", "O"], + coords=[[1 / 3, 2 / 3, 0.814571], [0.360521, 1 / 3, 0.583333]], + coords_are_cartesian=False, +) +# mp-1267 +nac = Structure( + cubic, + species=["Na", "C"], + coords=[[0.688819, 3 / 4, 3 / 8], [0.065833, 0.565833, 0.0]], + coords_are_cartesian=False, +) + + +@pytest.mark.parametrize( + "coords", + [ + alumina.cart_coords, + nac.cart_coords, + ], +) +@pytest.mark.parametrize( + "cell", + [ + hexa.matrix, + cubic.matrix, + ], +) +@pytest.mark.parametrize("self_loops", [True, False]) +@pytest.mark.parametrize("backend", ["pymatgen", "ase"]) +@pytest.mark.parametrize( + "cutoff_radius", [6.0, 9.0, 15.0] +) # TODO figure out why pmg fails on 3 +def test_periodic_generation( + coords: np.ndarray, + cell: np.ndarray, + self_loops: bool, + backend: str, + cutoff_radius: float, +): + coords = torch.FloatTensor(coords) + cell = torch.FloatTensor(cell) + transform = PeriodicPropertiesTransform( + cutoff_radius=cutoff_radius, + adaptive_cutoff=False, + backend=backend, + max_neighbors=10, + allow_self_loops=self_loops, + ) + num_atoms = coords.size(0) + atomic_numbers = torch.ones(num_atoms) + packed_data = {"pos": coords, "cell": cell, "atomic_numbers": atomic_numbers} + output = transform(packed_data) + # check to make sure no source node has more than 10 neighbors + src_nodes = output["src_nodes"].tolist() + counts = Counter(src_nodes) + for index, count in counts.items(): + if not self_loops: + assert count < 10, print(f"Node {index} has too many counts. {src_nodes}") diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index 56184072..fc496873 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -8,8 +8,10 @@ from os import makedirs from pathlib import Path from typing import Any, Callable +from itertools import product import lmdb +from loguru import logger import torch import numpy as np from einops import einsum, rearrange @@ -609,6 +611,8 @@ def make_pymatgen_periodic_structure( lat_angles: torch.Tensor | None = None, lat_abc: torch.Tensor | None = None, lattice: Lattice | None = None, + convert_to_unit_cell: bool = False, + is_cartesian: bool | None = None, ) -> Structure: """ Construct a Pymatgen structure from available information @@ -630,16 +634,26 @@ def make_pymatgen_periodic_structure( 1D tensor containing three elements for the lattice angles. lat_abc 1D tensor containing three elements for the lattice abc values. - + convert_to_unit_cell : bool, default False + If set to ``True``, the output structure have coordinates transformed + into fractional coordinates. + is_cartesian : bool | None, default None + If ``None``, the workflow makes an educated guess based on whether + coordinates are all within the range of [0, 1] - if not then they + are definitely cartesian. The user can override this by setting it + to ``True`` or ``False``. Returns ------- Structure Periodic structure object """ - if coords.max() > 1.0 or coords.min() < 0.0: - is_frac = False + if is_cartesian is None: + if coords.max() > 1.0 or coords.min() < 0.0: + is_frac = False + else: + is_frac = True else: - is_frac = True + is_frac = not is_cartesian # TODO this is logically confusing if not lattice: if lat_angles is None or lat_abc is None: raise ValueError( @@ -651,14 +665,17 @@ def make_pymatgen_periodic_structure( lattice, atomic_numbers, coords, - to_unit_cell=True, + to_unit_cell=convert_to_unit_cell, coords_are_cartesian=not is_frac, ) return structure def calculate_periodic_shifts( - structure: Structure, cutoff: float, adaptive_cutoff: bool = False + structure: Structure, + cutoff: float, + adaptive_cutoff: bool = False, + max_neighbors: int = 1000, ) -> dict[str, torch.Tensor]: """ Compute properties with respect to periodic boundary conditions. @@ -681,6 +698,10 @@ def calculate_periodic_shifts( Pymatgen periodic structure. cutoff Radial cut off for defining edges. + max_neighbors : int + Maximum number of neighbors a given site can have. This method + will count the number of edges per site, and the loop will + terminate earlier if the count exceeds this value. Returns ------- @@ -698,8 +719,9 @@ def calculate_periodic_shifts( ) # check to make sure the cell definition is valid if np.any(structure.frac_coords > 1.0): - raise ValueError( + logger.warning( f"Structure has fractional coordinates greater than 1! Check structure:\n{structure}" + f"\n fractional coordinates: {structure.frac_coords}" ) def _all_sites_have_neighbors(neighbors): @@ -724,11 +746,16 @@ def _all_sites_have_neighbors(neighbors): all_src, all_dst, all_images = [], [], [] for src_idx, dst_sites in enumerate(neighbors): + site_count = 0 for site in dst_sites: + if site_count > max_neighbors: + break all_src.append(src_idx) all_dst.append(site.index) all_images.append(site.image) - if any([len(obj) == 0 for obj in [all_images, all_dst, all_images]]): + # determine if we terminate the site loop earlier + site_count += 1 + if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]): raise ValueError( f"No images or edges to work off for cutoff {cutoff}." f" Please inspect your structure and neighbors: {structure} {neighbors} {structure.cart_coords}" @@ -756,7 +783,43 @@ def _all_sites_have_neighbors(neighbors): return return_dict -def calculate_ase_periodic_shifts(data, cutoff_radius, adaptive_cutoff): +def calculate_ase_periodic_shifts( + data: DataDict, + cutoff_radius: float, + adaptive_cutoff: bool, + max_neighbors: int = 1000, +) -> dict[str, torch.Tensor]: + """ + Calculate edges for the system using ``ase`` routines. + + This function will create an ``ase.Atoms`` object from the available data, + which should mirror in functionality to the ``pymatgen`` counterpart of + this function. + + Parameters + ---------- + data : DataDict + Dictionary containing a single data sample. + cutoff_radius : float + Distance to use for the neighborlist calculation. + adaptive_cutoff : bool + Whether to use the adaptive cut off algorithm. In the event + we arrive at a structure with atoms that are too far away + (i.e. a disconnected subgraph), we will progressively increase + the cutoff value. This allows the majority of graphs to have + a smaller cutoff value, while still allowing more troublesome + interactions to be modeled up to a maximum of 30 angstroms. + max_neighbors : int, default 1000 + Set the maximum number of edges a given atom can have. + The edges are not explicitly sorted in this function, + and we terminate the edge addition for a site once the + count exceeds this value. + + Returns + ------- + dict[str, torch.Tensor] + Dictionary containing key/value mappings for periodic properties. + """ cell = data["cell"] atoms = ase.Atoms( @@ -779,7 +842,7 @@ def _all_sites_have_neighbors(neighbors): # if there are sites without neighbors and user requested adaptive # cut off, we'll keep trying if not _all_sites_have_neighbors(neighbors) and adaptive_cutoff: - while not _all_sites_have_neighbors(neighbors) and cutoff < 30.0: + while not _all_sites_have_neighbors(neighbors) and cutoff_radius < 30.0: # increment radial cutoff progressively cutoff_radius += 0.5 cutoff = [cutoff_radius] * atoms.positions.shape[0] @@ -792,13 +855,18 @@ def _all_sites_have_neighbors(neighbors): all_src, all_dst, all_images = [], [], [] for src_idx in range(len(atoms)): + site_count = 0 dst_index, image = nl.get_neighbors(src_idx) for index in range(len(dst_index)): + if site_count > max_neighbors: + break all_src.append(src_idx) all_dst.append(dst_index[index]) all_images.append(image[index]) + # determine if we terminate the site loop earlier + site_count += 1 - if any([len(obj) == 0 for obj in [all_images, all_dst, all_images]]): + if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]): raise ValueError( f"No images or edges to work off for cutoff {cutoff}." f" Please inspect your atoms object and neighbors: {atoms}." @@ -815,9 +883,142 @@ def _all_sites_have_neighbors(neighbors): "pos": coords, } + # only do the reshape if we are missing a dimension + if cell.ndim == 2: + cell = rearrange(cell, "i j -> () i j") return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j") src, dst = return_dict["src_nodes"], return_dict["dst_nodes"] return_dict["unit_offsets"] = ( frac_coords[dst] - frac_coords[src] + return_dict["offsets"] ) return return_dict + + +def cart_frac_conversion( + coords: torch.Tensor, + a: float, + b: float, + c: float, + alpha: float, + beta: float, + gamma: float, + angles_are_degrees: bool = True, + to_fractional: bool = True, +) -> torch.Tensor: + """ + Convert coordinates from cartesians to fractional, or vice versa. + + Distances are expected to be in angstroms, while angles + are expected to be in degrees by default, but can be + passed directly as radians as well. + + Parameters + ---------- + coords : torch.Tensor + Coordinates of atoms; expects shape [N, 3] for + N atoms. + a : float + Cell length a. + b : float + Cell length b. + c : float + Cell length c. + alpha : float + Lattice angle alpha; expected units depend on + the ``angles_are_degrees`` flag. + beta : float + Lattice angle beta; expected units depend on + the ``angles_are_degrees`` flag. + gamma : float + Lattice angle gamma; expected units depend on + the ``angles_are_degrees`` flag. + angles_are_degrees : bool, default True + Flag to designate whether the angles passed + to this function are in degrees. Defaults to + True, which then assumes the angles are in degrees + and conversion to radians are done within the function. + to_fractional : bool, default True + Specifies that the input coordinates are cartesian, + and that we are transforming them into fractional + coordinates. + + Returns + ------- + torch.Tensor + Fractional coordinate representation + """ + + def cot(x: float) -> float: + """cotangent of x""" + return -np.tan(x + np.pi / 2) + + def csc(x: float) -> float: + """cosecant of x""" + return 1 / np.sin(x) + + # convert to radians if angles are passed as degrees + if angles_are_degrees: + alpha = alpha * np.pi / 180.0 + beta = beta * np.pi / 180.0 + gamma = gamma * np.pi / 180.0 + + # This matrix is normally for fractional to cart. Implements the matrix found in + # https://en.wikipedia.org/wiki/Fractional_coordinates#General_transformations_between_fractional_and_Cartesian_coordinates + rotation = torch.tensor( + [ + [ + a + * np.sin(beta) + * np.sqrt( + 1 + - ( + (cot(alpha) * cot(beta)) + - (csc(alpha) * csc(beta) * np.cos(gamma)) + ) + ** 2.0 + ), + 0.0, + 0.0, + ], + [ + a * csc(alpha) * np.cos(gamma) - a * cot(alpha) * np.cos(beta), + b * np.sin(alpha), + 0.0, + ], + [a * np.cos(beta), b * np.cos(alpha), c], + ], + dtype=coords.dtype, + ) + if to_fractional: + # invert elements for the opposite conversion + rotation = torch.linalg.inv(rotation) + output = coords @ rotation + return output + + +def build_nearest_images(max_image_number: int) -> torch.Tensor: + """ + Utility function to exhaustively construct images based off + a maximum (absolute value) image number. + + These images can be used for tiling primarily for testing + and development. + + Parameters + ---------- + max_image_number : int + Maximum image number (absolute value) to consider. The + resulting tensor will span +/- this value. + + Returns + ------- + torch.Tensor + Float tensor containing image indices. + """ + indices = product( + range(-max_image_number, max_image_number), + range(-max_image_number, max_image_number), + range(-max_image_number, max_image_number), + ) + images = torch.FloatTensor(list(indices)) + return images