Skip to content

Commit

Permalink
fix & test: added lattice conversion tests and fixes to make it pass
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Nov 27, 2024
1 parent 8e4481d commit e978fbe
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
2 changes: 1 addition & 1 deletion matsciml/datasets/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ def coordinate_consistency(self) -> Self:
"Fractional coordinate dimensions do not match cartesians."
)
)
if min(self.frac_coords) < 0.0 or max(self.frac_coords) > 1.0:
if self.frac_coords.min() < 0.0 or self.frac_coords.max() > 1.0:
self._exception_wrapper(
ValueError("Fractional coordinates are outside of [0, 1].")
)
Expand Down
22 changes: 22 additions & 0 deletions matsciml/datasets/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import numpy as np
import torch
from ase.geometry import cell_to_cellpar

from matsciml.datasets import schema
from matsciml.datasets.transforms import PeriodicPropertiesTransform
Expand Down Expand Up @@ -175,3 +176,24 @@ def test_data_sample_fail_coord_shape(num_atoms, array_lib):
pbc=pbc,
datatype="SCFCycle",
)


def test_lattice_param_to_matrix_consistency():
"""Make sure that lattice parameters map to matrix correctly during validation"""
coords = np.random.rand(5, 3)
numbers = np.random.randint(1, 100, (5))
data = schema.DataSampleSchema(
index=0,
num_atoms=5,
cart_coords=coords,
atomic_numbers=numbers,
pbc={"x": True, "y": True, "z": True},
datatype="OptimizationCycle",
lattice_parameters=[5.0, 5.0, 5.0, 90.0, 90.0, 90.0],
)
assert data.frac_coords is not None
assert data.lattice_matrix is not None
reconverted = cell_to_cellpar(data.lattice_matrix)
assert np.allclose(reconverted, data.lattice_parameters)
exact = np.array([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]])
assert np.allclose(exact, data.lattice_matrix)
8 changes: 5 additions & 3 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def csc(x: float) -> float:

# This matrix is normally for fractional to cart. Implements the matrix found in
# https://en.wikipedia.org/wiki/Fractional_coordinates#General_transformations_between_fractional_and_Cartesian_coordinates
rotation = torch.tensor(
rotation = np.array(
[
[
a
Expand All @@ -987,11 +987,13 @@ def csc(x: float) -> float:
],
[a * np.cos(beta), b * np.cos(alpha), c],
],
dtype=coords.dtype,
)
if to_fractional:
# invert elements for the opposite conversion
rotation = torch.linalg.inv(rotation)
rotation = np.linalg.inv(rotation)
# if coords are already torch, cast as a tensor so we can matmul
if isinstance(coords, torch.Tensor):
rotation = torch.from_numpy(rotation).to(coords.type)
output = coords @ rotation
return output

Expand Down

0 comments on commit e978fbe

Please sign in to comment.