From cc13b15da8bf35790ec35b88ad6600af964003be Mon Sep 17 00:00:00 2001 From: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> Date: Wed, 2 Oct 2024 22:17:16 -0700 Subject: [PATCH] adding back .cpu() before .numpy() in get_feature_matrix.py Signed-off-by: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> --- src/matgl/layers/_atom_ref.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matgl/layers/_atom_ref.py b/src/matgl/layers/_atom_ref.py index cc77de79..eb717cd9 100644 --- a/src/matgl/layers/_atom_ref.py +++ b/src/matgl/layers/_atom_ref.py @@ -41,7 +41,7 @@ def get_feature_matrix(self, graphs: list[dgl.DGLGraph]) -> torch.Tensor: for i, graph in enumerate(graphs): atomic_numbers = graph.ndata["node_type"] features[i] = torch.bincount(atomic_numbers, minlength=self.max_z) - return features.numpy() + return features.cpu().numpy() def fit(self, graphs: list[dgl.DGLGraph], properties: torch.Tensor) -> None: """Fit the elemental reference values for the properties.