Updates: We have added the script for model visualization (Figure 4 in our paper)!
This repository implements the Structure-Aware Transformer (SAT) in PyTorch Geometric, described in the following paper:
Dexiong Chen*, Leslie O'Bray*, and Karsten Borgwardt. Structure-Aware Transformer for Graph Representation Learning. ICML 2022.
*Equal contribution
TL;DR: A class of simple and flexible graph transformers built upon a new self-attention mechanism, which incorporates structural information into the original self-attention by extracting a subgraph representation rooted at each node before computing the attention. Our structure-aware framework can leverage any existing GNN to extract the subgraph representation and systematically improve the performance relative to the base GNN.
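In code, the idea looks roughly like the following sketch (a conceptual illustration only, not the repository's implementation; StructureAwareAttention and its internals are placeholder names):

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class StructureAwareAttention(nn.Module):
    # Queries and keys come from structure-aware node embeddings produced
    # by a small GNN (the structure extractor); values come from the
    # original node features.
    def __init__(self, dim, num_heads=4, k_hop=2):
        super().__init__()
        self.extractor = nn.ModuleList([GCNConv(dim, dim) for _ in range(k_hop)])
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x, edge_index):
        # 1) structure-aware node embeddings from a k-layer GCN
        h = x
        for conv in self.extractor:
            h = conv(h, edge_index).relu()
        # 2) standard self-attention with structure-aware queries/keys
        out, _ = self.attn(h.unsqueeze(0), h.unsqueeze(0), x.unsqueeze(0))
        return out.squeeze(0)

In the actual model (GraphTransformer below), this happens inside every transformer layer, and the structure extractor can be any GNN, e.g. GCN, GINE, or PNA.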
Please use the following to cite our work:
@InProceedings{Chen22a,
author = {Dexiong Chen and Leslie O'Bray and Karsten Borgwardt},
title = {Structure-Aware Transformer for Graph Representation Learning},
year = {2022},
booktitle = {Proceedings of the 39th International Conference on Machine Learning~(ICML)},
series = {Proceedings of Machine Learning Research}
}
The SAT architecture compared with the vanilla transformer architecture is shown above. We make the self-attention calculation in each transformer layer structure-aware by leveraging structure-aware node embeddings. We generate these embeddings by applying a structure extractor (for example, any GNN) to the subgraph rooted at each node before computing the attention.
The figure above shows the two example structure extractors used in our paper (the k-subtree and k-subgraph extractors).
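For intuition, here is a rough illustration of what the two extractors compute, using PyG's k_hop_subgraph utility on a toy graph (a hedged sketch, not the repository's code; the gnn helper and the sum pooling are illustrative assumptions):

import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import k_hop_subgraph

k_hop = 2
convs = torch.nn.ModuleList([GCNConv(16, 16) for _ in range(k_hop)])

def gnn(x, edge_index):
    # a small k-layer GCN used as the base GNN of both extractors
    for conv in convs:
        x = conv(x, edge_index).relu()
    return x

# Toy path graph 0-1-2-3 with random node features
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.randn(4, 16)

# k-subtree: run the k-layer GNN on the full graph; node v's embedding
# summarizes the k-hop subtree rooted at v.
h_subtree = gnn(x, edge_index)

# k-subgraph: extract the k-hop subgraph around each node, run the GNN
# on it, and pool its node embeddings into one vector per root node.
h_subgraph = []
for v in range(x.size(0)):
    nodes, sub_edge_index, _, _ = k_hop_subgraph(
        v, k_hop, edge_index, relabel_nodes=True)
    h_subgraph.append(gnn(x[nodes], sub_edge_index).sum(dim=0))
h_subgraph = torch.stack(h_subgraph)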
Below you can find a quick-start example on the ZINC dataset; see ./experiments/train_zinc.py for more details.
import torch
from torch_geometric import datasets
from torch_geometric.loader import DataLoader

from sat.data import GraphDataset
from sat import GraphTransformer

# Load the ZINC dataset using our wrapper GraphDataset,
# which automatically creates the fully connected graph.
# For datasets with large graphs, we recommend setting
# return_complete_index=False, which leads to faster computation.
dset = datasets.ZINC('./datasets/ZINC', subset=True, split='train')
dset = GraphDataset(dset)

# Create a PyG data loader
train_loader = DataLoader(dset, batch_size=16, shuffle=True)

# Create a SAT model
dim_hidden = 16
gnn_type = 'gcn'  # use GCN as the structure extractor
k_hop = 2         # use a 2-layer GCN

model = GraphTransformer(
    in_size=28,                  # number of node labels for ZINC
    num_class=1,                 # regression task
    d_model=dim_hidden,
    dim_feedforward=2 * dim_hidden,
    num_layers=2,
    batch_norm=True,
    gnn_type=gnn_type,           # use GCN as the structure extractor
    use_edge_attr=True,
    num_edge_features=4,         # number of edge labels
    edge_dim=dim_hidden,
    k_hop=k_hop,
    se='gnn',                    # use the k-subtree structure extractor
    global_pool='add',
)

for data in train_loader:
    output = model(data)  # batch_size x 1
    break
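To go from the forward pass above to actual training, a minimal loop could look like this (a sketch continuing the example above; the full training setup lives in train_zinc.py):

import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    total_loss = 0.0
    for data in train_loader:
        optimizer.zero_grad()
        output = model(data)                           # batch_size x 1
        loss = F.l1_loss(output.squeeze(-1), data.y)   # MAE, ZINC is regression
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    print(f'epoch {epoch}: train MAE {total_loss / len(dset):.4f}')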
The dependencies are managed by miniconda:
python=3.9
numpy
scipy
pytorch=1.9.1
pytorch-geometric=2.0.2
einops
ogb
Once you have activated the environment and installed all dependencies, run:
source s
Datasets will be downloaded via PyTorch Geometric and the OGB package.
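For reference, this is roughly how those datasets are fetched (the root paths here are illustrative assumptions; the experiment scripts take care of downloading for you):

from torch_geometric import datasets
from ogb.graphproppred import PygGraphPropPredDataset

# Each call downloads and caches the data on first use
zinc = datasets.ZINC('./datasets/ZINC', subset=True, split='train')
pattern = datasets.GNNBenchmarkDataset('./datasets/SBM', name='PATTERN', split='train')
ppa = PygGraphPropPredDataset(name='ogbg-ppa', root='./datasets')
code2 = PygGraphPropPredDataset(name='ogbg-code2', root='./datasets')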
All our experimental scripts are in the experiments folder. To start, after having run source s, run cd experiments. The hyperparameters used below are the ones we selected as optimal.
Train a k-subtree SAT with PNA:
python train_zinc.py --abs-pe rw --se gnn --gnn-type pna2 --dropout 0.3 --k-hop 3 --use-edge-attr
Train a k-subgraph SAT with PNA:
python train_zinc.py --abs-pe rw --se khopgnn --gnn-type pna2 --dropout 0.2 --k-hop 3 --use-edge-attr
Train a k-subtree SAT on PATTERN:
python train_SBMs.py --dataset PATTERN --weight-class --abs-pe rw --abs-pe-dim 7 --se gnn --gnn-type pna3 --dropout 0.2 --k-hop 3 --num-layers 6 --lr 0.0003
and on CLUSTER:
python train_SBMs.py --dataset CLUSTER --weight-class --abs-pe rw --abs-pe-dim 3 --se gnn --gnn-type pna2 --dropout 0.4 --k-hop 3 --num-layers 16 --dim-hidden 48 --lr 0.0005
--gnn-type can be gcn, gine, or pna, where pna obtains the best performance.
# Train SAT on OGBG-PPA
python train_ppa.py --gnn-type gcn --use-edge-attr
# Train SAT on OGBG-CODE2
python train_code2.py --gnn-type gcn --use-edge-attr
We showcase here how to visualize the attention weights of the [CLS] node learned by SAT and by the vanilla Transformer, both with the random walk positional encoding. We provide pre-trained models on the Mutagenicity dataset. To visualize them, you need to install the networkx and matplotlib packages, then run:
python model_visu.py --graph-idx 2003
This will generate the following image, the same as Figure 4 in our paper:
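If you would like to build a similar visualization yourself, the basic recipe is to color each node by its attention weight from the [CLS] node (a sketch assuming you already have a PyG Data object data and a per-node weight vector attn; model_visu.py is the full version used for the figure):

import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

# `data` is a PyG Data object and `attn` a tensor of per-node attention
# weights from the [CLS] node (both assumed to be available).
G = to_networkx(data, to_undirected=True)
pos = nx.kamada_kawai_layout(G)
nx.draw_networkx(G, pos, with_labels=False, node_size=120,
                 node_color=attn.detach().numpy(), cmap=plt.cm.viridis)
plt.axis('off')
plt.savefig('attention.png', bbox_inches='tight')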