diff --git a/src/matgl/graph/data.py b/src/matgl/graph/data.py index 7ffb2dfe..93836e68 100644 --- a/src/matgl/graph/data.py +++ b/src/matgl/graph/data.py @@ -189,7 +189,7 @@ def process(self): if self.graph_labels is not None: state_attrs = torch.tensor(self.graph_labels).long() else: - state_attrs = torch.tensor(np.array(state_attrs)) + state_attrs = torch.tensor(np.array(state_attrs), dtype=matgl.float_th) if self.clear_processed: del self.structures diff --git a/src/matgl/models/_tensornet.py b/src/matgl/models/_tensornet.py index 9050dbbb..66d8131a 100644 --- a/src/matgl/models/_tensornet.py +++ b/src/matgl/models/_tensornet.py @@ -180,8 +180,9 @@ def __init__( ) self.out_norm = nn.LayerNorm(3 * units, dtype=dtype) + self.linear = nn.Linear(3 * units, units, dtype=dtype) if is_intensive: - input_feats = 3 * units if field == "node_feat" else units + input_feats = units if readout_type == "set2set": self.readout = Set2SetReadOut( in_feats=input_feats, n_iters=niters_set2set, n_layers=nlayers_set2set, field=field @@ -203,7 +204,7 @@ def __init__( if task_type == "classification": raise ValueError("Classification task cannot be extensive.") self.final_layer = WeightedReadOut( - in_feats=3 * units, + in_feats=units, dims=[units, units], num_targets=ntargets, # type: ignore ) @@ -247,6 +248,7 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, **kwa x = torch.cat((tensor_norm(scalars), tensor_norm(skew_metrices), tensor_norm(traceless_tensors)), dim=-1) x = self.out_norm(x) + x = self.linear(x) g.ndata["node_feat"] = x if self.is_intensive: