Skip to content

Commit

Permalink
Fix edge index manipulation to make training work again
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 12, 2024
1 parent ce204b6 commit 22bfe65
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 2 additions & 2 deletions neural_lam/interaction_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(
# any edge_index used here must start sender and rec. nodes at index 0
# Store number of receiver nodes according to edge_index
self.num_rec = edge_index[1].max() + 1
edge_index[0] = (
edge_index[0] + self.num_rec
edge_index = torch.stack(
(edge_index[0] + self.num_rec, edge_index[1]), dim=0
) # Make sender indices after rec
self.register_buffer("edge_index", edge_index, persistent=False)

Expand Down
2 changes: 0 additions & 2 deletions neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ def __init__(self, args):
super().__init__(args)

# Load graph with static features
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
# num_mesh_nodes indices,
self.hierarchical, graph_ldict = utils.load_graph(args.graph)
for name, attr_value in graph_ldict.items():
# Make BufferLists module members and register tensors as buffers
Expand Down

0 comments on commit 22bfe65

Please sign in to comment.