From 8a076f0beec7d195d553e1e5b58eaa66c6344f3d Mon Sep 17 00:00:00 2001 From: Leon van Bokhorst Date: Tue, 12 Nov 2024 05:30:50 +0100 Subject: [PATCH] Update requirements.txt and GATConv in gat_social_network.py --- requirements.txt | 12 ++---------- src/gat_social_network.py | 9 +++++---- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/requirements.txt b/requirements.txt index c96c9ea..cf6dd12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,16 +25,8 @@ seaborn pydot pygraphviz -# PyTorch ecosystem - install these first ---find-links https://download.pytorch.org/whl/torch_stable.html +# PyTorch ecosystem torch torchvision torchaudio - -# PyTorch Geometric dependencies - install these after PyTorch -torch-geometric -torch-scatter -torch-sparse -torch-cluster -torch-spline-conv -torch-geometric-temporal \ No newline at end of file +torch-geometric \ No newline at end of file diff --git a/src/gat_social_network.py b/src/gat_social_network.py index b56b004..5a1f9ee 100644 --- a/src/gat_social_network.py +++ b/src/gat_social_network.py @@ -5,7 +5,7 @@ import networkx as nx import torch import torch.nn as nn -from torch_geometric.nn import GATConv +from torch_geometric.nn import GATConv # Changed from GCNConv from torch_geometric.data import Data import matplotlib.pyplot as plt @@ -78,10 +78,11 @@ def __init__( # GAT layer: Each head learns different aspects of relationships # Output dim per head is hidden_dim // num_heads to maintain constant total dimensions self.gat_layer = GATConv( - hidden_dim, - hidden_dim // num_heads, + in_channels=hidden_dim, + out_channels=hidden_dim // num_heads, heads=num_heads, - concat=True, # Concatenate outputs from different heads + concat=True, + dropout=0.0 # Optional: Add dropout for regularization ).to(device) def add_agent(self, agent: AgentNode) -> None: