diff --git a/src/matgl/models/_tensornet.py b/src/matgl/models/_tensornet.py index 66d8131a..bfed4610 100644 --- a/src/matgl/models/_tensornet.py +++ b/src/matgl/models/_tensornet.py @@ -234,12 +234,10 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, **kwa g.edata["bond_vec"] = bond_vec.to(g.device) g.edata["bond_dist"] = bond_dist.to(g.device) - # This asserts convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor] - # Expand distances with radial basis functions edge_attr = self.bond_expansion(g.edata["bond_dist"]) g.edata["edge_attr"] = edge_attr - # Embedding from edge-wise tensors to node-wise tensors + # Embedding layer X, edge_feat, state_feat = self.tensor_embedding(g, state_attr) # Interaction layers for layer in self.layers: