CS224W - Bag of Tricks for Node Classification with GNN - LogE Loss #9836

Open · wants to merge 11 commits into master
Conversation

mattjhayes3

Implement $\text{log-}\epsilon$ loss functions and modules

Part of #9831. As described in “Bag of Tricks for Node Classification with Graph Neural Networks”, this non-convex loss is thought to be less sensitive to outliers: it provides its maximal gradient at the decision boundary while still giving significant signal for all misclassified examples.
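
For reference, a sketch of the per-example form as we read it from the paper (the exact normalization in this PR may differ). Applied to the standard cross-entropy $\ell_i = -\log p_{i,y_i}$:

$$\mathcal{L}_{\text{log-}\epsilon} = \frac{1}{n}\sum_{i=1}^{n}\bigl(\log(\epsilon + \ell_i) - \log\epsilon\bigr), \qquad \epsilon = 1 - \log 2.$$

With this choice of $\epsilon$, the per-example gradient magnitude $1/\bigl(p\,(\epsilon - \log p)\bigr)$ peaks at $p = 1/2$, i.e. the binary decision boundary, while remaining nonzero for every misclassified example; the $-\log\epsilon$ term only shifts the loss so it is zero for a perfectly classified example.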

Details

  • Implements drop-in module and functional variants corresponding to `nll_loss`, `cross_entropy`, `binary_cross_entropy`, and `binary_cross_entropy_with_logits`
  • Longer term, PyTorch might be the best home for these, but PyG seems like a good place in the meantime, as the paper shows the loss is more effective on GNNs than on MLPs
  • As far as we can tell, PyG does not yet have any loss functions, but `torch_geometric.nn.functional` seemed like a reasonable place for them to live. Happy to move them to `contrib` as well.
  • Implemented as simple wrappers around the corresponding PyTorch losses for easy maintainability (see the sketch below)
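
For concreteness, a minimal sketch of what the functional `cross_entropy` variant could look like as such a wrapper; names and defaults here are illustrative, not necessarily the exact merged API:

```python
import math

import torch
import torch.nn.functional as F


def loge_cross_entropy(
    input: torch.Tensor,
    target: torch.Tensor,
    epsilon: float = 1.0 - math.log(2.0),
) -> torch.Tensor:
    # Per-example cross-entropy, computed by the existing PyTorch loss:
    ce = F.cross_entropy(input, target, reduction='none')
    # Non-convex log-epsilon transform; subtracting log(epsilon) makes the
    # loss zero when an example is classified perfectly (ce == 0).
    return (torch.log(epsilon + ce) - math.log(epsilon)).mean()
```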

Benchmarks

  • Running benchmarks/citation on Colab T4 GPUs, we see the loss can bring small but statistically significant gains
    • It seems to work well with GAT more consistently than with other models
  • In some settings, however, it can cause large regressions (e.g. it works very poorly with SGC), so we recommend that users also try more traditional losses to see what works best
    • The validation delta does not always correlate with the test delta, but in those cases the test regression is usually not too large
  • Training is typically a bit slower due to the extra computation, but can be faster end-to-end thanks to early stopping
  • A selection of deltas that were statistically significant on test accuracy at a confidence of at least 95% is included below, expressed in the direction loge - nll
    • For Arxiv, the default GAT parameters from benchmarks/citation were used with a batch norm inserted, though these are surely suboptimal settings
    • Full results are available here
| Command (nll baseline) | Val. acc. Δ (abs.) | Test acc. Δ (abs.) | Duration Δ (rel.) |
| --- | --- | --- | --- |
| `gcn.py --dataset=CiteSeer` | 0.73% | 0.76% | 2.62% |
| `gcn.py --dataset=PubMed` | 0.42% | -0.47% | 7.38% |
| `gat.py --dataset=Cora` | 0.19% | -0.21% | 2.81% |
| `gat.py --dataset=CiteSeer` | 0.86% | 0.70% | -6.07% |
| `gat.py --dataset=PubMed --lr=0.01 --output_heads=8 --weight_decay=0.001` | 0.40% | 0.18% | 2.32% |
| `gat.py --batch_norm --dataset=Arxiv --no_normalize_features --runs=20` | 0.73% | 0.67% | -1.54% |
| `cheb.py --dataset=Cora --num_hops=3` | 0.65% | 0.66% | -4.62% |
| `arma.py --dataset=Cora --num_layers=1 --num_stacks=2 --shared_weights` | -0.10% | 0.19% | -1.23% |
| `arma.py --dataset=CiteSeer --num_layers=1 --num_stacks=3 --shared_weights` | 0.50% | 0.61% | 2.76% |
| `sgc.py --K=3 --dataset=Cora --weight_decay=0.0005` | -13.69% | -12.37% | -27.50% |
| `sgc.py --K=2 --dataset=PubMed --weight_decay=0.0005` | -13.00% | -15.80% | 10.01% |
| `appnp.py --alpha=0.1 --dataset=Cora` | -0.38% | -0.28% | -3.66% |
| `appnp.py --alpha=0.1 --dataset=CiteSeer` | 0.50% | 0.53% | 0.69% |
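
For context, swapping the loss into a benchmarks/citation-style training step is a one-line change. The sketch below uses a hypothetical `loge_nll_loss` (the `nll_loss` analogue of the wrapper sketched above), since the citation models return log-probabilities; the function and training-loop names are illustrative, not the scripts' actual code:

```python
def train(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # Baseline: loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss = loge_nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)
```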

@mattjhayes3 marked this pull request as ready for review on December 10, 2024.