diff --git a/src/matgl/layers/_atom_ref.py b/src/matgl/layers/_atom_ref.py index cc77de79..eb717cd9 100644 --- a/src/matgl/layers/_atom_ref.py +++ b/src/matgl/layers/_atom_ref.py @@ -41,7 +41,7 @@ def get_feature_matrix(self, graphs: list[dgl.DGLGraph]) -> torch.Tensor: for i, graph in enumerate(graphs): atomic_numbers = graph.ndata["node_type"] features[i] = torch.bincount(atomic_numbers, minlength=self.max_z) - return features.numpy() + return features.cpu().numpy() def fit(self, graphs: list[dgl.DGLGraph], properties: torch.Tensor) -> None: """Fit the elemental reference values for the properties.