diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4a18280634a4..b73fba470704 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `SGFormer` model and example ([#9904](https://github.com/pyg-team/pytorch_geometric/pull/9904))
 - Consolidate Cugraph examples into ogbn_train_cugraph.py and ogbn_train_cugraph_multigpu.py for ogbn-arxiv, ogbn-products and ogbn-papers100M ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953))
 - Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
 - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
diff --git a/examples/README.md b/examples/README.md
index b7a44510903b..2eee1b94fb33 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -20,3 +20,7 @@ For examples on scaling PyG up via multi-GPUs, see the examples under [`examples
 For examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero).
 
 For examples on co-training LLMs with GNNs, see the examples under [`examples/llm`](./llm).
+
+For an example on training a Graph Transformer, see [`examples/sgformer.py`](./sgformer.py).
+
+- [SGFormer Paper](https://arxiv.org/pdf/2306.10759)
diff --git a/examples/sgformer.py b/examples/sgformer.py
new file mode 100644
index 000000000000..b0c7bd6cbf60
--- /dev/null
+++ b/examples/sgformer.py
@@ -0,0 +1,306 @@
+"""This example runs the SGFormer model on the ogbn-arxiv dataset.
+Original Paper: https://arxiv.org/abs/2306.10759
+"SGFormer: Simplifying and Empowering Transformers for
+Large-Graph Representations".
+"""
+import argparse
+
+import torch
+import torch.nn.functional as F
+from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
+
+from torch_geometric import seed_everything
+from torch_geometric.nn.attention import SGFormerAttention
+from torch_geometric.nn.conv import GCNConv
+from torch_geometric.utils import (
+    add_self_loops,
+    remove_self_loops,
+    to_undirected,
+)
+
+
+class GraphModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+        for _ in range(num_layers):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x, edge_index):
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        last_x = x  # residual connection to the input projection
+
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = x + last_x
+        return x
+
+
+class SGModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        num_heads=1,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.attns = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.LayerNorm(hidden_channels))
+        for _ in range(num_layers):
+            self.attns.append(
+                SGFormerAttention(hidden_channels, num_heads, hidden_channels))
+            self.bns.append(torch.nn.LayerNorm(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for attn in self.attns:
+            attn.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x):
+        layer_ = []
+
+        # input MLP layer
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # store as residual link
+        layer_.append(x)
+
+        for i, attn in enumerate(self.attns):
+            x = attn(x)
+            x = (x + layer_[i]) / 2.
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            layer_.append(x)
+
+        return x
+
+
+class Net(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        trans_num_layers=1,
+        trans_num_heads=1,
+        trans_dropout=0.5,
+        gnn_num_layers=1,
+        gnn_dropout=0.5,
+        graph_weight=0.8,
+        aggregate='add',
+    ):
+        super().__init__()
+        self.trans_conv = SGModule(
+            in_channels,
+            hidden_channels,
+            trans_num_layers,
+            trans_num_heads,
+            trans_dropout,
+        )
+        self.graph_conv = GraphModule(
+            in_channels,
+            hidden_channels,
+            gnn_num_layers,
+            gnn_dropout,
+        )
+        self.graph_weight = graph_weight
+
+        self.aggregate = aggregate
+
+        if aggregate == 'add':
+            self.fc = torch.nn.Linear(hidden_channels, out_channels)
+        elif aggregate == 'cat':
+            self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)
+        else:
+            raise ValueError(f'Invalid aggregate type: {aggregate}')
+
+        # separate parameter groups so that the Transformer and GNN branches
+        # can use different weight decay values:
+        self.params1 = list(self.trans_conv.parameters())
+        self.params2 = list(self.graph_conv.parameters())
+        self.params2.extend(list(self.fc.parameters()))
+
+    def forward(self, x, edge_index):
+        x1 = self.trans_conv(x)
+        x2 = self.graph_conv(x, edge_index)
+        if self.aggregate == 'add':
+            x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
+        else:
+            x = torch.cat((x1, x2), dim=1)
+        x = self.fc(x)
+        return x
+
+
+def main(args):
+    # load dataset ============================================================
+    dataset = PygNodePropPredDataset(name=args.dataset, root=args.data_dir)
+    g = dataset[0]
+    # get the splits
+    split_idx = dataset.get_idx_split()
+    # basic information of datasets
+    n = g.num_nodes
+    e = g.num_edges
+    # infer the number of classes from the label tensor
+    c = g.y.max().item() + 1
+    d = g.x.size(1)
+
+    print(f'dataset {args.dataset} | #nodes {n} | #edges {e} '
+          f'| #node feats {d} | #classes {c}')
+
+    g.edge_index = to_undirected(g.edge_index)
+    g.edge_index, _ = remove_self_loops(g.edge_index)
+    g.edge_index, _ = add_self_loops(g.edge_index, num_nodes=n)
+
+    # select device ==========================================================
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    g = g.to(device)
+
+    # define model ===================================================
+    model = Net(
+        d,
+        args.hidden_channels,
+        c,
+        graph_weight=args.graph_weight,
+        aggregate=args.aggregate,
+        trans_num_layers=args.trans_num_layers,
+        trans_dropout=args.trans_dropout,
+        trans_num_heads=args.trans_num_heads,
+        gnn_num_layers=args.gnn_num_layers,
+        gnn_dropout=args.gnn_dropout,
+    ).to(device)
+    criterion = torch.nn.NLLLoss()
+    optimizer = torch.optim.Adam([
+        {
+            'params': model.params1,
+            'weight_decay': args.trans_weight_decay
+        },
+        {
+            'params': model.params2,
+            'weight_decay': args.gnn_weight_decay
+        },
+    ], lr=args.lr)
+
+    # define evaluator ===============================================
+    evaluator = Evaluator(name=args.dataset)
+
+    def evaluate(out, split):
+        assert split in ['train', 'valid', 'test']
+        y_true = g.y.cpu()
+        y_pred = out.argmax(dim=-1, keepdim=True)
+        return evaluator.eval({
+            'y_true': y_true[split_idx[split]],
+            'y_pred': y_pred[split_idx[split]],
+        })['acc']
+
+    # training loop ==================================================
+    train_idx = split_idx['train'].to(device)
+    best_val, best_test = 0, 0
+    for epoch in range(args.epochs):
+        model.train()
+        optimizer.zero_grad()
+        out = model(g.x, g.edge_index)
+        out = F.log_softmax(out, dim=-1)
+        loss = criterion(out[train_idx], g.y[train_idx].view(-1))
+        loss.backward()
+        optimizer.step()
+
+        if epoch % args.eval_step == 0:
+            model.eval()
+            with torch.no_grad():
+                out = model(g.x, g.edge_index)
+                out = F.log_softmax(out, dim=-1)
+            train_acc = evaluate(out, 'train')
+            valid_acc = evaluate(out, 'valid')
+            test_acc = evaluate(out, 'test')
+            if valid_acc > best_val:
+                best_val = valid_acc
+                best_test = test_acc
+            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, '
+                  f'Train: {100 * train_acc:.2f}%, '
+                  f'Valid: {100 * valid_acc:.2f}%, '
+                  f'Test: {100 * test_acc:.2f}%')
+    print(f'Best Test Acc: {100 * best_test:.2f}%')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='SGFormer Example')
+    # dataset and evaluation
+    parser.add_argument('--dataset', type=str, default='ogbn-arxiv')
+    parser.add_argument('--data_dir', type=str, default='./data/ogb')
+    parser.add_argument('--epochs', type=int, default=1000)
+
+    # gnn branch
+    parser.add_argument('--hidden_channels', type=int, default=256)
+    parser.add_argument('--aggregate', type=str, default='add',
+                        help="aggregation type: 'add' or 'cat'")
+    parser.add_argument('--graph_weight', type=float, default=0.5,
+                        help="weight of the GNN branch (used with 'add')")
+
+    parser.add_argument('--gnn_num_layers', type=int, default=3,
+                        help='number of layers for GNN')
+    parser.add_argument('--gnn_dropout', type=float, default=0.5)
+    parser.add_argument('--gnn_weight_decay', type=float, default=0.)
+
+    # all-pair attention (Transformer) branch
+    parser.add_argument('--trans_num_heads', type=int, default=1,
+                        help='number of heads for attention')
+    parser.add_argument('--trans_num_layers', type=int, default=2,
+                        help='number of layers for all-pair attention')
+    parser.add_argument('--trans_dropout', type=float, default=0.5)
+    parser.add_argument('--trans_weight_decay', type=float, default=0.)
+
+    # training
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--batch_size', type=int, default=10000,
+                        help='mini-batch size for large graphs '
+                        '(currently unused: this example trains full-batch)')
+
+    # display and utility
+    parser.add_argument('--eval_step', type=int, default=9,
+                        help='evaluate every `eval_step` epochs')
+    args = parser.parse_args()
+    print(args)
+    seed_everything(123)
+    main(args)
diff --git a/test/nn/attention/test_sgformer.py b/test/nn/attention/test_sgformer.py
new file mode 100644
index 000000000000..7b9c1ea499da
--- /dev/null
+++ b/test/nn/attention/test_sgformer.py
@@ -0,0 +1,12 @@
+import torch
+
+from torch_geometric.nn.attention import SGFormerAttention
+
+
+def test_sgformer_attention():
+    x = torch.randn(1, 4, 16)
+    attn = SGFormerAttention(channels=16, heads=1)
+    out = attn(x)
+    assert out.shape == (4, 64)
+    assert str(attn) == ('SGFormerAttention(heads=1, '
+                         'head_channels=64)')
diff --git a/torch_geometric/nn/attention/__init__.py b/torch_geometric/nn/attention/__init__.py
index 6b4064cd34b9..dd2e0a039674 100644
--- a/torch_geometric/nn/attention/__init__.py
+++ b/torch_geometric/nn/attention/__init__.py
@@ -1,7 +1,9 @@
 from .performer import PerformerAttention
 from .qformer import QFormer
+from .sgformer import SGFormerAttention
 
 __all__ = [
     'PerformerAttention',
     'QFormer',
+    'SGFormerAttention',
 ]
diff --git a/torch_geometric/nn/attention/sgformer.py b/torch_geometric/nn/attention/sgformer.py
new file mode 100644
index 000000000000..81d0dbd7e505
--- /dev/null
+++ b/torch_geometric/nn/attention/sgformer.py
@@ -0,0 +1,78 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+class SGFormerAttention(torch.nn.Module):
+    r"""The simple global attention mechanism from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        channels (int): Size of each input sample.
+        heads (int, optional): Number of parallel attention heads.
+            (default: :obj:`1`)
+        head_channels (int, optional): Size of each attention head.
+            (default: :obj:`64`)
+        qkv_bias (bool, optional): If set to :obj:`True`, adds bias to the
+            query, key and value projections. (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        channels: int,
+        heads: int = 1,
+        head_channels: Optional[int] = 64,
+        qkv_bias: bool = False,
+    ) -> None:
+        super().__init__()
+        assert channels % heads == 0
+        assert heads == 1, 'The number of heads is fixed to 1.'
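+        # If `head_channels` is passed as `None`, fall back to splitting
+        # `channels` evenly across the attention heads: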
+        if head_channels is None:
+            head_channels = channels // heads
+
+        self.heads = heads
+        self.head_channels = head_channels
+
+        inner_channels = head_channels * heads
+        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # feature transformation
+        qs = self.q(x).reshape(-1, self.heads, self.head_channels)
+        ks = self.k(x).reshape(-1, self.heads, self.head_channels)
+        vs = self.v(x).reshape(-1, self.heads, self.head_channels)
+
+        # normalize input
+        qs = qs / torch.norm(qs, p=2)  # [N, H, M]
+        ks = ks / torch.norm(ks, p=2)  # [L, H, M]
+        N = qs.shape[0]
+
+        # numerator
+        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
+        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]
+        attention_num += N * vs
+
+        # denominator
+        all_ones = torch.ones(ks.shape[0], device=ks.device)
+        ks_sum = torch.einsum("lhm,l->hm", ks, all_ones)
+        attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]
+
+        # attentive aggregated results
+        attention_normalizer = attention_normalizer.unsqueeze(-1)  # [N, H, 1]
+        attention_normalizer += torch.ones_like(attention_normalizer) * N
+        attn_output = attention_num / attention_normalizer  # [N, H, D]
+
+        return attn_output.mean(dim=1)
+
+    def reset_parameters(self):
+        self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
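
As a quick smoke test outside of the diff above, `SGFormerAttention` can be exercised on its own, mirroring `test/nn/attention/test_sgformer.py`. A minimal sketch, assuming this branch is installed; the node count of 10 and feature size of 16 are arbitrary choices:

import torch

from torch_geometric.nn.attention import SGFormerAttention

# 10 nodes with 16-dimensional features each.
x = torch.randn(10, 16)

# `channels` must match the input feature size; `heads` is fixed to 1.
attn = SGFormerAttention(channels=16, heads=1, head_channels=64)

# forward() computes linear-complexity global attention and averages over
# the heads, so the output has shape [num_nodes, head_channels]:
out = attn(x)
print(out.shape)  # torch.Size([10, 64])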