Skip to content

Commit

Permalink
refactor: moving lattice param calculations to before validator
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Nov 27, 2024
1 parent 7e807ca commit 8e4481d
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions matsciml/datasets/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,21 +776,27 @@ def _exception_wrapper(self, exception: Exception):
f"Data schema validation failed at sample {self.index}."
) from exception

@model_validator(mode="before")
@classmethod
def convert_lattice_and_parameters(cls, values: Any) -> Any:
lattice_params = values.get("lattice_parameters", None)
lattice_matrix = values.get("lattice_matrix", None)
if lattice_params is None and lattice_matrix is not None:
lattice_params = cell_to_cellpar(lattice_matrix)
values["lattice_parameters"] = lattice_params
if lattice_params is not None and lattice_matrix is None:
lattice_matrix = cellpar_to_cell(lattice_params)
values["lattice_matrix"] = lattice_matrix
return values

@model_validator(mode="after")
def coordinate_consistency(self) -> Self:
"""Sets fractional coordinates if parameters are available, and checks them"""
# convert lattice to lattice parameters if available
if self.lattice_matrix is not None and self.lattice_parameters is None:
self.lattice_parameters = cell_to_cellpar(self.lattice_matrix)
# convert lattice parameters to cell if available
if self.lattice_matrix is None and self.lattice_parameters is not None:
self.lattice_matrix = cellpar_to_cell(self.lattice_parameters)
# calculate fractional coordinates if we have lattice parameters
if self.frac_coords is None and self.lattice_parameters is not None:
self.frac_coords = cart_frac_conversion(
self.cart_coords, *self.lattice_parameters, to_fractional=True
)
if self.frac_coords is not None:
if isinstance(self.frac_coords, NDArray):
if self.frac_coords.shape != self.cart_coords.shape:
self._exception_wrapper(
ValueError(
Expand Down

0 comments on commit 8e4481d

Please sign in to comment.