From a9c3c41235ca4f3de48455a03e0ce8498302eec7 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Wed, 27 Nov 2024 13:10:25 -0800 Subject: [PATCH] fix: correcting model output validation for scalar energies --- matsciml/common/types.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/matsciml/common/types.py b/matsciml/common/types.py index be9642fe..f91ede85 100644 --- a/matsciml/common/types.py +++ b/matsciml/common/types.py @@ -141,8 +141,12 @@ def standardize_total_energy( ``ValueError``. """ if isinstance(values, torch.Tensor): - # drop all redundant dimensions - values = values.squeeze() + if values.ndim > 1: + # drop all redundant dimensions + values = values.squeeze() + # we want to at least have a vector of energies + if values.ndim == 0: + values = values.unsqueeze(0) # last step is an assertion check for QA if values.ndim != 1: raise ValueError(