CS224W - Bag of Tricks for Node Classification with GNN - Non interactive GAT #9832

@mattjhayes3 commented on Dec 9, 2024

Add interactive_attn parameter to GATConv and GATv2Conv.

Part of #9831, this enables “non-interactive” attention as described in “Bag of Tricks for Node Classification with Graph Neural Networks”, where target-node features are not used when computing the attention coefficients.
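
Concretely, the only change is in the attention logits: the non-interactive variant drops the target-node term. A rough sketch, with $i$ the target node and $j$ the source node (the weight names here are illustrative, not the exact attribute names of the implementation):

$$e_{j,i}^{\textrm{interactive}} = \mathrm{LeakyReLU}\left(\mathbf{a}_s^{\top}\mathbf{\Theta}_s\mathbf{x}_j + \mathbf{a}_t^{\top}\mathbf{\Theta}_t\mathbf{x}_i\right), \qquad e_{j,i}^{\textrm{non-interactive}} = \mathrm{LeakyReLU}\left(\mathbf{a}_s^{\top}\mathbf{\Theta}_s\mathbf{x}_j\right)$$

In both cases the attention coefficients $\alpha_{j,i}$ are still obtained by softmax-normalizing $e_{j,i}$ over each target node's incoming edges, so only the logit computation differs.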

Details

  • There is already a subtle way to achieve non-interactive attention in GATConv by passing (x, None) to forward (see the usage sketch after this list); however, this approach
    • Doesn’t work with residuals in the bipartite case
    • Isn’t implemented for GATv2Conv
    • Adds complexity to hyperparameter search compared to passing a kwarg
    • Can’t be used beyond the first layer with the existing models.GAT
    • Is not documented and not very readable
  • We currently include some documentation and comment changes that are only tangentially related. Happy to revert them to keep the change focused, but we believe they make the workings of the module a little clearer.
    • Tuple in_channels can be used with non-bipartite graphs, e.g. users might want a separate set of features depending on whether a node acts as a source or a target.
    • Swap s and t in the notation, since i and j are the target and source nodes, respectively, with the default flow. This is also more consistent with GATConv’s lin_src and lin_dst. Similarly, according to the message passing documentation, $e_{j, i}$ is the indexing used for edge features with the default flow, though the papers use the opposite notation.
  • We could slightly reduce the changes in GATConv.forward by reusing the tuple in_channels path when interactive_attn=False and in_channels is an integer, but this could be confusing
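
For illustration, a minimal usage sketch of the proposed kwarg next to the existing (x, None) workaround (assuming interactive_attn defaults to True; shapes and hyperparameters are arbitrary):

```python
import torch
from torch_geometric.nn import GATConv, GATv2Conv

x = torch.randn(4, 16)                    # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 0, 3, 2]]) # target nodes

# Existing, undocumented workaround: passing (x, None) makes GATConv skip the
# target-node term of the attention logits, with the caveats listed above.
conv = GATConv(16, 8, heads=2)
out = conv((x, None), edge_index)

# Proposed: an explicit kwarg, supported by both GATConv and GATv2Conv.
conv = GATConv(16, 8, heads=2, interactive_attn=False)
out = conv(x, edge_index)

conv_v2 = GATv2Conv(16, 8, heads=2, interactive_attn=False)
out_v2 = conv_v2(x, edge_index)
```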

Benchmarks

The paper does not provide any metrics or performance references. Running benchmarks/citation shows that non-interactive attention is ~2–16% faster on Colab's T4 GPUs, typically with no effect on accuracy. It may be slightly worse on Cora (v1 without random splits, and v2 with random splits) but slightly better on PubMed (v1 with random splits).

Deltas below are expressed in the direction interactive − non_interactive. For Arxiv we used the default hyperparameters from benchmarks/citation/gat.py with a batch norm inserted, and for v2 the same settings as for v1, though these are surely suboptimal.

| Command | Val. acc. Δ (abs.) | Val. acc. p-value | Test acc. Δ (abs.) | Test acc. p-value | Duration Δ (rel.) | Duration p-value |
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
| `gat.py --dataset=Cora` | 0.14% | 12.31% | 0.11% | 6.54% | 3.80% | 0.08% |
| `gat.py --dataset=Cora --v2` | 0.05% | 58.70% | 0.01% | 82.66% | 7.55% | 0.00% |
| `gat.py --dataset=Cora --random_splits` | 0.01% | 96.71% | -0.16% | 45.78% | 2.77% | 0.75% |
| `gat.py --dataset=Cora --random_splits --v2` | -0.09% | 76.30% | 0.44% | 4.97% | 6.15% | 0.00% |
| `gat.py --dataset=CiteSeer` | -0.25% | 2.00% | 0.08% | 26.67% | 2.04% | 1.96% |
| `gat.py --dataset=CiteSeer --v2` | 0.12% | 27.36% | -0.01% | 86.19% | 13.58% | 0.00% |
| `gat.py --dataset=CiteSeer --random_splits` | 0.18% | 56.38% | -0.15% | 58.17% | 4.67% | 0.00% |
| `gat.py --dataset=CiteSeer --random_splits --v2` | 0.18% | 57.16% | -0.13% | 61.36% | 13.28% | 0.00% |
| `gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --weight_decay=0.001` | -0.04% | 64.61% | 0.00% | 92.15% | 7.02% | 0.00% |
| `gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --v2 --weight_decay=0.001` | 0.03% | 70.07% | 0.00% | 94.62% | 15.79% | 0.00% |
| `gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --random_splits --weight_decay=0.001` | -0.31% | 45.31% | -0.59% | 9.35% | 7.63% | 0.00% |
| `gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --random_splits --v2 --weight_decay=0.001` | 0.14% | 71.39% | 0.16% | 64.16% | 15.54% | 0.00% |
| `gat.py --batch_norm --dataset=Arxiv --no_normalize_features --runs=20` | -0.00% | 94.04% | -0.08% | 43.23% | 5.65% | 0.00% |
| `gat.py --batch_norm --dataset=Arxiv --no_normalize_features --runs=20 --v2` | 0.02% | 72.76% | -0.10% | 33.38% | 12.63% | 0.00% |

Full metrics: [interactive] [non-interactive]
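
For context, a minimal sketch of how delta and p-value columns of this kind could be computed from per-run accuracies; the Welch t-test and the made-up accuracies below are illustrative assumptions, not necessarily the exact procedure behind the table above:

```python
# Hypothetical helper for comparing two sets of per-run accuracies; the Welch
# t-test is an assumption, not necessarily the statistic used for the table.
import numpy as np
from scipy import stats

def compare(interactive: np.ndarray, non_interactive: np.ndarray):
    """Return the mean delta (interactive - non_interactive) and a two-sided p-value."""
    delta = interactive.mean() - non_interactive.mean()
    _, pval = stats.ttest_ind(interactive, non_interactive, equal_var=False)
    return delta, pval

# Made-up per-run test accuracies for 10 runs of each variant.
rng = np.random.default_rng(0)
acc_interactive = 0.825 + 0.005 * rng.standard_normal(10)
acc_non_interactive = 0.824 + 0.005 * rng.standard_normal(10)

delta, pval = compare(acc_interactive, acc_non_interactive)
print(f"test_acc_abs_delta={delta:+.2%}, test_acc_pval={pval:.2%}")
```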
