diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e7038eae919..ccf2a5589ee6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `SGFormer` model and example ([#9904](https://github.com/pyg-team/pytorch_geometric/pull/9904)) - Added `AveragePopularity` metric for link prediction ([#10022](https://github.com/pyg-team/pytorch_geometric/pull/10022)) - Added `Personalization` metric for link prediction ([#10015](https://github.com/pyg-team/pytorch_geometric/pull/10015)) - Added `HitRatio` metric for link prediction ([#10013](https://github.com/pyg-team/pytorch_geometric/pull/10013)) diff --git a/examples/README.md b/examples/README.md index b7a44510903b..06eb6b5e8719 100644 --- a/examples/README.md +++ b/examples/README.md @@ -10,6 +10,8 @@ For a simple link prediction example, see [`link_pred.py`](./link_pred.py). For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see the `ogbn_*.py` examples: - [`ogbn_train.py`](./ogbn_train.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately ~1.6B edges or the medium scale `ogbn-products` dataset, ~62M edges. + - Uses SGFormer (a kind of GraphTransformer) by default. + - [SGFormer Paper](https://arxiv.org/pdf/2306.10759) - [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset. - [`ogbn_train_cugraph.py`](./ogbn_train_cugraph.py) shows how to accelerate the `ogbn_train.py` workflow using [CuGraph](https://github.com/rapidsai/cugraph). diff --git a/examples/multi_gpu/ogbn_train_cugraph.py b/examples/multi_gpu/ogbn_train_cugraph.py index a4d9692936ab..118ebb45650a 100644 --- a/examples/multi_gpu/ogbn_train_cugraph.py +++ b/examples/multi_gpu/ogbn_train_cugraph.py @@ -21,7 +21,12 @@ from torch.nn.parallel import DistributedDataParallel import torch_geometric -from torch_geometric.utils import to_undirected +from torch_geometric import seed_everything +from torch_geometric.utils import ( + add_self_loops, + remove_self_loops, + to_undirected, +) # Allow computation on objects that are larger than GPU memory # https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory @@ -38,7 +43,7 @@ def arg_parse(): parser.add_argument( '--dataset', type=str, - default='ogbn-papers100M', + default='ogbn-arxiv', choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'], help='Dataset name.', ) @@ -51,35 +56,41 @@ def arg_parse(): parser.add_argument( "--dataset_subdir", type=str, - default="ogb-papers100M", + default="ogbn-arxiv", help="directory of dataset.", ) parser.add_argument('--hidden_channels', type=int, default=256) - parser.add_argument('--num_layers', type=int, default=2) + parser.add_argument('--num_layers', type=int, default=3) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--wd', type=float, default=0.000) - parser.add_argument('-e', '--epochs', type=int, default=10) + parser.add_argument('-e', '--epochs', type=int, default=50) parser.add_argument('-b', '--batch_size', type=int, default=1024) parser.add_argument('--fan_out', type=int, default=10) parser.add_argument('--eval_steps', type=int, default=1000) parser.add_argument('--warmup_steps', type=int, default=20) + parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument( '--use_directed_graph', action='store_true', help='Whether or not to use directed graph', ) + 
parser.add_argument( + '--add_self_loop', + action='store_true', + help='Whether or not to add self loop', + ) parser.add_argument( "--model", type=str, default='GCN', - choices=['SAGE', 'GAT', 'GCN'], - help="Model used for training, default GraphSAGE", + choices=['SAGE', 'GAT', 'GCN', 'SGFormer'], + help="Model used for training, default GCN", ) parser.add_argument( - "--num_gat_conv_heads", + "--num_heads", type=int, - default=4, - help="If using GATConv, number of attention heads to use", + default=1, + help="If using GATConv or GT, number of attention heads to use", ) parser.add_argument('--tempdir_root', type=str, default=None) parser.add_argument( @@ -310,7 +321,7 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx, if __name__ == '__main__': args = arg_parse() - + seed_everything(123) wall_clock_start = time.perf_counter() root = osp.join(args.dataset_dir, args.dataset_subdir) @@ -319,17 +330,20 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx, data = dataset[0] if not args.use_directed_graph: data.edge_index = to_undirected(data.edge_index, reduce="mean") + if args.add_self_loop: + data.edge_index, _ = remove_self_loops(data.edge_index) + data.edge_index, _ = add_self_loops(data.edge_index, + num_nodes=data.num_nodes) data.y = data.y.reshape(-1) + print(f"Training {args.dataset} with {args.model} model.") if args.model == "GAT": - print(f"Training {args.dataset} with GAT model.") model = torch_geometric.nn.models.GAT(dataset.num_features, args.hidden_channels, args.num_layers, dataset.num_classes, - heads=args.num_gat_conv_heads) + heads=args.num_heads) elif args.model == "GCN": - print(f"Training {args.dataset} with GCN model.") model = torch_geometric.nn.models.GCN( dataset.num_features, args.hidden_channels, @@ -337,15 +351,24 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx, dataset.num_classes, ) elif args.model == "SAGE": - print(f"Training {args.dataset} with GraphSAGE model.") model = torch_geometric.nn.models.GraphSAGE( dataset.num_features, args.hidden_channels, args.num_layers, dataset.num_classes, ) + elif args.model == 'SGFormer': + model = torch_geometric.nn.models.SGFormer( + in_channels=dataset.num_features, + hidden_channels=args.hidden_channels, + out_channels=dataset.num_classes, + trans_num_heads=args.num_heads, + trans_dropout=args.dropout, + gnn_num_layers=args.num_layers, + gnn_dropout=args.dropout, + ) else: - raise ValueError('Unsupported model type: {args.model}') + raise ValueError(f'Unsupported model type: {args.model}') print("Data =", data) if args.num_devices < 1: diff --git a/examples/ogbn_train.py b/examples/ogbn_train.py index d975b86f51d0..7b834f7a3150 100644 --- a/examples/ogbn_train.py +++ b/examples/ogbn_train.py @@ -9,17 +9,22 @@ from torch import Tensor from tqdm import tqdm +from torch_geometric import seed_everything from torch_geometric.loader import NeighborLoader -from torch_geometric.nn.models import GAT, GraphSAGE -from torch_geometric.utils import to_undirected +from torch_geometric.nn.models import GAT, GraphSAGE, SGFormer +from torch_geometric.utils import ( + add_self_loops, + remove_self_loops, + to_undirected, +) parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( '--dataset', type=str, - default='ogbn-papers100M', - choices=['ogbn-papers100M', 'ogbn-products'], + default='ogbn-arxiv', + choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'], help='Dataset name.', ) parser.add_argument( 
@@ -29,13 +34,15 @@ help='Root directory of dataset.', ) parser.add_argument( - '--use_gat', - action='store_true', - help='Whether or not to use GAT model', + '--gnn_choice', + type=str, + default='sgformer', + choices=['gat', 'graphsage', 'sgformer'], + help='Model name.', ) parser.add_argument('-e', '--epochs', type=int, default=50) parser.add_argument('--num_layers', type=int, default=3) -parser.add_argument('--num_heads', type=int, default=2, +parser.add_argument('--num_heads', type=int, default=1, help='number of heads for GAT model.') parser.add_argument('-b', '--batch_size', type=int, default=1024) parser.add_argument('--num_workers', type=int, default=12) @@ -50,6 +57,11 @@ action='store_true', help='Whether or not to use directed graph', ) +parser.add_argument( + '--add_self_loop', + action='store_true', + help='Whether or not to add self loop', +) args = parser.parse_args() wall_clock_start = time.perf_counter() @@ -60,11 +72,9 @@ print('Consider upgrading RAM if an error occurs.') print('Estimated RAM Needed: ~390GB.') -if args.use_gat: - print(f'Training {args.dataset} with GAT model.') -else: - print(f'Training {args.dataset} with GraphSage model.') +print(f'Training {args.dataset} with {args.gnn_choice} model.') +seed_everything(123) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') num_epochs = args.epochs num_layers = args.num_layers @@ -79,6 +89,10 @@ if not args.use_directed_graph: data.edge_index = to_undirected(data.edge_index, reduce='mean') +if args.add_self_loop: + data.edge_index, _ = remove_self_loops(data.edge_index) + data.edge_index, _ = add_self_loops(data.edge_index, + num_nodes=data.num_nodes) data.to(device, 'x', 'y') @@ -154,25 +168,41 @@ def test(loader: NeighborLoader) -> float: return total_correct / total_examples -if args.use_gat: - model = GAT( - in_channels=dataset.num_features, - hidden_channels=num_hidden_channels, - num_layers=num_layers, - out_channels=dataset.num_classes, - dropout=args.dropout, - heads=args.num_heads, - ) -else: - model = GraphSAGE( - in_channels=dataset.num_features, - hidden_channels=num_hidden_channels, - num_layers=num_layers, - out_channels=dataset.num_classes, - dropout=args.dropout, - ) - -model = model.to(device) +def get_model(gnn_choice: str) -> torch.nn.Module: + if gnn_choice == 'gat': + model = GAT( + in_channels=dataset.num_features, + hidden_channels=num_hidden_channels, + num_layers=num_layers, + out_channels=dataset.num_classes, + dropout=args.dropout, + heads=args.num_heads, + ) + elif gnn_choice == 'graphsage': + model = GraphSAGE( + in_channels=dataset.num_features, + hidden_channels=num_hidden_channels, + num_layers=num_layers, + out_channels=dataset.num_classes, + dropout=args.dropout, + ) + elif gnn_choice == 'sgformer': + model = SGFormer( + in_channels=dataset.num_features, + hidden_channels=num_hidden_channels, + out_channels=dataset.num_classes, + trans_num_heads=args.num_heads, + trans_dropout=args.dropout, + gnn_num_layers=num_layers, + gnn_dropout=args.dropout, + ) + else: + raise ValueError(f'Unsupported model type: {gnn_choice}') + + return model + + +model = get_model(args.gnn_choice).to(device) model.reset_parameters() optimizer = torch.optim.Adam( model.parameters(), diff --git a/examples/ogbn_train_cugraph.py b/examples/ogbn_train_cugraph.py index bcd0b36238bc..b6e53a4785a5 100644 --- a/examples/ogbn_train_cugraph.py +++ b/examples/ogbn_train_cugraph.py @@ -27,10 +27,8 @@ # Enable cudf spilling to save gpu memory from cugraph_pyg.loader import NeighborLoader # noqa from 
ogb.nodeproppred import PygNodePropPredDataset # noqa -from tqdm import tqdm # noqa import torch_geometric # noqa -from torch_geometric.utils import to_undirected # noqa cudf.set_option("spill", True) @@ -41,7 +39,7 @@ def arg_parse(): parser.add_argument( '--dataset', type=str, - default='ogbn-papers100M', + default='ogbn-arxiv', choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'], help='Dataset name.', ) @@ -54,10 +52,10 @@ def arg_parse(): parser.add_argument( "--dataset_subdir", type=str, - default="ogb-papers100M", + default="ogbn-arxiv", help="directory of dataset.", ) - parser.add_argument('-e', '--epochs', type=int, default=10) + parser.add_argument('-e', '--epochs', type=int, default=50) parser.add_argument('--num_layers', type=int, default=3) parser.add_argument('-b', '--batch_size', type=int, default=1024) parser.add_argument('--fan_out', type=int, default=10) @@ -72,18 +70,23 @@ def arg_parse(): action='store_true', help='Whether or not to use directed graph', ) + parser.add_argument( + '--add_self_loop', + action='store_true', + help='Whether or not to add self loop', + ) parser.add_argument( "--model", type=str, - default='SAGE', - choices=['SAGE', 'GAT', 'GCN'], - help="Model used for training, default GraphSAGE", + default='SGFormer', + choices=['SAGE', 'GAT', 'GCN', 'SGFormer'], + help="Model used for training, default SGFormer", ) parser.add_argument( - "--num_gat_conv_heads", + "--num_heads", type=int, - default=4, - help="If using GATConv, number of attention heads to use", + default=1, + help="If using GATConv or GT, number of attention heads to use", ) parser.add_argument('--tempdir_root', type=str, default=None) args = parser.parse_args() @@ -153,6 +156,7 @@ def test(model, loader): if __name__ == '__main__': args = arg_parse() + torch_geometric.seed_everything(123) if "papers" in str(args.dataset) and (psutil.virtual_memory().total / (1024**3)) < 390: print("Warning: may not have enough RAM to use this many GPUs.") @@ -167,7 +171,13 @@ def test(model, loader): data = dataset[0] if not args.use_directed_graph: - data.edge_index = to_undirected(data.edge_index, reduce="mean") + data.edge_index = torch_geometric.utils.to_undirected( + data.edge_index, reduce="mean") + if args.add_self_loop: + data.edge_index, _ = torch_geometric.utils.remove_self_loops( + data.edge_index) + data.edge_index, _ = torch_geometric.utils.add_self_loops( + data.edge_index, num_nodes=data.num_nodes) graph_store = cugraph_pyg.data.GraphStore() graph_store[dict( @@ -178,18 +188,19 @@ def test(model, loader): )] = data.edge_index feature_store = cugraph_pyg.data.TensorDictFeatureStore() - feature_store['node', 'x'] = data.x - feature_store['node', 'y'] = data.y + feature_store['node', 'x', None] = data.x + feature_store['node', 'y', None] = data.y data = (feature_store, graph_store) + print(f"Training {args.dataset} with {args.model} model.") if args.model == "GAT": - print(f"Training {args.dataset} with GAT model.") - model = torch_geometric.nn.models.GAT( - dataset.num_features, args.hidden_channels, args.num_layers, - dataset.num_classes, heads=args.num_gat_conv_heads).cuda() + model = torch_geometric.nn.models.GAT(dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + heads=args.num_heads).cuda() elif args.model == "GCN": - print(f"Training {args.dataset} with GCN model.") model = torch_geometric.nn.models.GCN( dataset.num_features, args.hidden_channels, @@ -197,13 +208,22 @@ def test(model, loader): dataset.num_classes, ).cuda() elif args.model == 
"SAGE": - print(f"Training {args.dataset} with GraphSAGE model.") model = torch_geometric.nn.models.GraphSAGE( dataset.num_features, args.hidden_channels, args.num_layers, dataset.num_classes, ).cuda() + elif args.model == 'SGFormer': + model = torch_geometric.nn.models.SGFormer( + in_channels=dataset.num_features, + hidden_channels=args.hidden_channels, + out_channels=dataset.num_classes, + trans_num_heads=args.num_heads, + trans_dropout=args.dropout, + gnn_num_layers=args.num_layers, + gnn_dropout=args.dropout, + ).cuda() else: raise ValueError('Unsupported model type: {args.model}') diff --git a/test/nn/models/test_sgformer.py b/test/nn/models/test_sgformer.py new file mode 100644 index 000000000000..8d5b1c3dfe5c --- /dev/null +++ b/test/nn/models/test_sgformer.py @@ -0,0 +1,19 @@ +import torch + +from torch_geometric.nn.models import SGFormer + + +def test_sgformer(): + x = torch.ones(10, 16, dtype=torch.float32) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], + ]) + + model = SGFormer( + in_channels=16, + hidden_channels=128, + out_channels=40, + ) + out = model(x, edge_index) + assert out.size() == (10, 40) diff --git a/torch_geometric/nn/attention/__init__.py b/torch_geometric/nn/attention/__init__.py index 6b4064cd34b9..dd2e0a039674 100644 --- a/torch_geometric/nn/attention/__init__.py +++ b/torch_geometric/nn/attention/__init__.py @@ -1,7 +1,9 @@ from .performer import PerformerAttention from .qformer import QFormer +from .sgformer import SGFormerAttention __all__ = [ 'PerformerAttention', 'QFormer', + 'SGFormerAttention', ] diff --git a/torch_geometric/nn/attention/sgformer.py b/torch_geometric/nn/attention/sgformer.py new file mode 100644 index 000000000000..81d0dbd7e505 --- /dev/null +++ b/torch_geometric/nn/attention/sgformer.py @@ -0,0 +1,78 @@ +import torch +from torch import Tensor + + +class SGFormerAttention(torch.nn.Module): + r"""The simple global attention mechanism from the + `"SGFormer: Simplifying and Empowering Transformers for + Large-Graph Representations" + `_ paper. + + Args: + channels (int): Size of each input sample. + heads (int, optional): Number of parallel attention heads. + (default: :obj:`1.`) + head_channels (int, optional): Size of each attention head. + (default: :obj:`64.`) + qkv_bias (bool, optional): If specified, add bias to query, key + and value in the self attention. (default: :obj:`False`) + """ + def __init__( + self, + channels: int, + heads: int = 1, + head_channels: int = 64, + qkv_bias: bool = False, + ) -> None: + super().__init__() + assert channels % heads == 0 + assert heads == 1, 'The number of heads are fixed as 1.' 
+        if head_channels is None:
+            head_channels = channels // heads
+
+        self.heads = heads
+        self.head_channels = head_channels
+
+        inner_channels = head_channels * heads
+        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # feature transformation
+        qs = self.q(x).reshape(-1, self.heads, self.head_channels)
+        ks = self.k(x).reshape(-1, self.heads, self.head_channels)
+        vs = self.v(x).reshape(-1, self.heads, self.head_channels)
+
+        # normalize input
+        qs = qs / torch.norm(qs, p=2)  # [N, H, M]
+        ks = ks / torch.norm(ks, p=2)  # [L, H, M]
+        N = qs.shape[0]
+
+        # numerator
+        kvs = torch.einsum("lhm,lhd->hmd", ks, vs)
+        attention_num = torch.einsum("nhm,hmd->nhd", qs, kvs)  # [N, H, D]
+        attention_num += N * vs
+
+        # denominator
+        all_ones = torch.ones([ks.shape[0]]).to(ks.device)
+        ks_sum = torch.einsum("lhm,l->hm", ks, all_ones)
+        attention_normalizer = torch.einsum("nhm,hm->nh", qs, ks_sum)  # [N, H]
+
+        # attentive aggregated results
+        attention_normalizer = torch.unsqueeze(
+            attention_normalizer, len(attention_normalizer.shape))  # [N, H, 1]
+        attention_normalizer += torch.ones_like(attention_normalizer) * N
+        attn_output = attention_num / attention_normalizer  # [N, H, D]
+
+        return attn_output.mean(dim=1)
+
+    def reset_parameters(self):
+        self.q.reset_parameters()
+        self.k.reset_parameters()
+        self.v.reset_parameters()
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'heads={self.heads}, '
+                f'head_channels={self.head_channels})')
diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py
index 9ade58cebc05..3a0a4a7fad44 100644
--- a/torch_geometric/nn/models/__init__.py
+++ b/torch_geometric/nn/models/__init__.py
@@ -32,6 +32,7 @@
 from .git_mol import GITMol
 from .molecule_gpt import MoleculeGPT
 from .glem import GLEM
+from .sgformer import SGFormer
 # Deprecated:
 from torch_geometric.explain.algorithm.captum import (to_captum_input,
                                                       captum_output_to_dicts)
@@ -82,4 +83,5 @@
     'GITMol',
     'MoleculeGPT',
     'GLEM',
+    'SGFormer',
 ]
diff --git a/torch_geometric/nn/models/sgformer.py b/torch_geometric/nn/models/sgformer.py
new file mode 100644
index 000000000000..a7bbb15b9251
--- /dev/null
+++ b/torch_geometric/nn/models/sgformer.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn.functional as F
+
+from torch_geometric.nn.attention import SGFormerAttention
+from torch_geometric.nn.conv import GCNConv
+
+
+class GraphModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+        for _ in range(num_layers):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x, edge_index):
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        last_x = x
+
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = x + last_x
+        return x
+
+
+class SGModule(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        num_layers=2,
+        num_heads=1,
+        dropout=0.5,
+    ):
+        super().__init__()
+
+        self.attns = torch.nn.ModuleList()
+        self.fcs = torch.nn.ModuleList()
+        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
+        self.bns = torch.nn.ModuleList()
+        self.bns.append(torch.nn.LayerNorm(hidden_channels))
+        for _ in range(num_layers):
+            self.attns.append(
+                SGFormerAttention(hidden_channels, num_heads, hidden_channels))
+            self.bns.append(torch.nn.LayerNorm(hidden_channels))
+
+        self.dropout = dropout
+        self.activation = F.relu
+
+    def reset_parameters(self):
+        for attn in self.attns:
+            attn.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+        for fc in self.fcs:
+            fc.reset_parameters()
+
+    def forward(self, x):
+        layer_ = []
+
+        # input MLP layer
+        x = self.fcs[0](x)
+        x = self.bns[0](x)
+        x = self.activation(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # store as residual link
+        layer_.append(x)
+
+        for i, attn in enumerate(self.attns):
+            x = attn(x)
+            x = (x + layer_[i]) / 2.
+            x = self.bns[i + 1](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            layer_.append(x)
+
+        return x
+
+
+class SGFormer(torch.nn.Module):
+    r"""The SGFormer model from the
+    `"SGFormer: Simplifying and Empowering Transformers for
+    Large-Graph Representations"
+    <https://arxiv.org/abs/2306.10759>`_ paper.
+
+    Args:
+        in_channels (int): Input channels.
+        hidden_channels (int): Hidden channels.
+        out_channels (int): Output channels.
+        trans_num_layers (int): The number of layers for all-pair attention.
+            (default: :obj:`2`)
+        trans_num_heads (int): The number of heads for attention.
+            (default: :obj:`1`)
+        trans_dropout (float): Global dropout rate.
+            (default: :obj:`0.5`)
+        gnn_num_layers (int): The number of layers for GNN.
+            (default: :obj:`3`)
+        gnn_dropout (float): GNN dropout rate.
+            (default: :obj:`0.5`)
+        graph_weight (float): The weight for balancing the global attention
+            and GNN modules. (default: :obj:`0.5`)
+        aggregate (str): Aggregation scheme (:obj:`'add'` or :obj:`'cat'`).
+            (default: :obj:`add`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        trans_num_layers: int = 2,
+        trans_num_heads: int = 1,
+        trans_dropout: float = 0.5,
+        gnn_num_layers: int = 3,
+        gnn_dropout: float = 0.5,
+        graph_weight: float = 0.5,
+        aggregate: str = 'add',
+    ):
+        super().__init__()
+        self.trans_conv = SGModule(
+            in_channels,
+            hidden_channels,
+            trans_num_layers,
+            trans_num_heads,
+            trans_dropout,
+        )
+        self.graph_conv = GraphModule(
+            in_channels,
+            hidden_channels,
+            gnn_num_layers,
+            gnn_dropout,
+        )
+        self.graph_weight = graph_weight
+
+        self.aggregate = aggregate
+
+        if aggregate == 'add':
+            self.fc = torch.nn.Linear(hidden_channels, out_channels)
+        elif aggregate == 'cat':
+            self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)
+        else:
+            raise ValueError(f'Invalid aggregate type: {aggregate}')
+
+        self.params1 = list(self.trans_conv.parameters())
+        self.params2 = list(self.graph_conv.parameters())
+        self.params2.extend(list(self.fc.parameters()))
+
+    def reset_parameters(self) -> None:
+        self.trans_conv.reset_parameters()
+        self.graph_conv.reset_parameters()
+        self.fc.reset_parameters()
+
+    def forward(self, x, edge_index):
+        x1 = self.trans_conv(x)
+        x2 = self.graph_conv(x, edge_index)
+        if self.aggregate == 'add':
+            x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
+        else:
+            x = torch.cat((x1, x2), dim=1)
+        x = self.fc(x)
+        return F.log_softmax(x, dim=-1)
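A minimal usage sketch of the new model, mirroring the unit test added in test/nn/models/test_sgformer.py. The toy graph, labels, channel sizes, and optimizer settings below are illustrative assumptions and are not part of this change; only the SGFormer constructor arguments and the (x, edge_index) call signature come from the code above.

import torch
import torch.nn.functional as F

from torch_geometric.nn.models import SGFormer

# Toy graph: 10 nodes with 16-dimensional features on a directed cycle.
x = torch.randn(10, 16)
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
])
y = torch.randint(0, 4, (10, ))  # Random node labels, for illustration only.

model = SGFormer(
    in_channels=16,
    hidden_channels=64,
    out_channels=4,
    trans_num_heads=1,  # SGFormerAttention currently asserts heads == 1.
    trans_dropout=0.5,
    gnn_num_layers=3,
    gnn_dropout=0.5,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(5):
    optimizer.zero_grad()
    out = model(x, edge_index)  # [10, 4] log-probabilities (log_softmax).
    loss = F.nll_loss(out, y)   # Model returns log-probs, so NLL loss matches.
    loss.backward()
    optimizer.step()

model.eval()
pred = model(x, edge_index).argmax(dim=-1)  # [10] predicted class indices.

Since forward() already applies log_softmax, training scripts pair it with F.nll_loss rather than F.cross_entropy, which is how the updated ogbn_train.py example uses it.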