Skip to content

Commit

Permalink
Merge pull request #42 from leonvanbokhorst/fix-torch-geometric-gatconv
Browse files Browse the repository at this point in the history
Update requirements.txt and GATConv in gat_social_network.py
  • Loading branch information
leonvanbokhorst authored Nov 12, 2024
2 parents 0d6876f + 8a076f0 commit cd9d230
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
12 changes: 2 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,8 @@ seaborn
pydot
pygraphviz

# PyTorch ecosystem - install these first
--find-links https://download.pytorch.org/whl/torch_stable.html
# PyTorch ecosystem
torch
torchvision
torchaudio

# PyTorch Geometric dependencies - install these after PyTorch
torch-geometric
torch-scatter
torch-sparse
torch-cluster
torch-spline-conv
torch-geometric-temporal
torch-geometric
9 changes: 5 additions & 4 deletions src/gat_social_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import networkx as nx
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
from torch_geometric.nn import GATConv # Changed from GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -78,10 +78,11 @@ def __init__(
# GAT layer: Each head learns different aspects of relationships
# Output dim per head is hidden_dim // num_heads to maintain constant total dimensions
self.gat_layer = GATConv(
hidden_dim,
hidden_dim // num_heads,
in_channels=hidden_dim,
out_channels=hidden_dim // num_heads,
heads=num_heads,
concat=True, # Concatenate outputs from different heads
concat=True,
dropout=0.0 # Optional: Add dropout for regularization
).to(device)

def add_agent(self, agent: AgentNode) -> None:
Expand Down

0 comments on commit cd9d230

Please sign in to comment.