Skip to content

Commit

Permalink
fix: correcting model output validation for scalar energies
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Nov 27, 2024
1 parent 7db4282 commit a9c3c41
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions matsciml/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ def standardize_total_energy(
``ValueError``.
"""
if isinstance(values, torch.Tensor):
# drop all redundant dimensions
values = values.squeeze()
if values.ndim > 1:
# drop all redundant dimensions
values = values.squeeze()
# we want to at least have a vector of energies
if values.ndim == 0:
values = values.unsqueeze(0)
# last step is an assertion check for QA
if values.ndim != 1:
raise ValueError(
Expand Down

0 comments on commit a9c3c41

Please sign in to comment.