-
Notifications
You must be signed in to change notification settings - Fork 150
/
gat_predictor.py
139 lines (125 loc) · 6.72 KB
/
gat_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# GAT-based model for regression and classification on graphs.
# pylint: disable= no-member, arguments-differ, invalid-name
import warnings

import torch.nn as nn

from .mlp_predictor import MLPPredictor
from ..gnn.gat import GAT
from ..readout.weighted_sum_and_max import WeightedSumAndMax
# pylint: disable=W0221
class GATPredictor(nn.Module):
    r"""GAT-based model for regression and classification on graphs.

    GAT is introduced in `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__.
    This model is based on GAT and can be used for regression and classification on graphs.

    After updating node representations, we perform a weighted sum with learnable
    weights and max pooling on them and concatenate the output of the two operations,
    which is then fed into an MLP for final prediction.

    For classification tasks, the output will be logits, i.e.
    values before sigmoid or softmax.

    Parameters
    ----------
    in_feats : int
        Number of input node features
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
        ``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
    num_heads : list of int
        ``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
        ``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention heads
        for each GAT layer.
    feat_drops : list of float
        ``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
        ``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
        all GAT layers.
    attn_drops : list of float
        ``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
        layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
        for all GAT layers.
    alphas : list of float
        Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
        gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
        number of GAT layers. By default, this will be 0.2 for all GAT layers.
    residuals : list of bool
        ``residuals[i]`` decides if residual connection is to be used for the i-th GAT layer.
        ``len(residuals)`` equals the number of GAT layers. By default, residual connection
        is performed for each GAT layer.
    agg_modes : list of str
        The way to aggregate multi-head attention results for each GAT layer, which can be either
        'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
        ``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
        GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, we flatten
        multi-head results for intermediate GAT layers and compute mean of multi-head results
        for the last GAT layer.
    activations : list of activation function or None
        ``activations[i]`` gives the activation function applied to the aggregated multi-head
        results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
        By default, ELU is applied for intermediate GAT layers and no activation is applied
        for the last GAT layer.
    biases : list of bool
        ``biases[i]`` gives whether to add bias for the i-th GAT layer. ``len(biases)``
        equals the number of GAT layers. By default, bias is added for all GAT layers.
    classifier_hidden_feats : int
        (Deprecated, see ``predictor_hidden_feats``) Size of hidden graph representations
        in the classifier. Default to 128.
    classifier_dropout : float
        (Deprecated, see ``predictor_dropout``) The probability for dropout in the classifier.
        Default to 0.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    predictor_hidden_feats : int
        Size for hidden representations in the output MLP predictor. Default to 128.
    predictor_dropout : float
        The probability for dropout in the output MLP predictor. Default to 0.
    """

    def __init__(self, in_feats, hidden_feats=None, num_heads=None, feat_drops=None,
                 attn_drops=None, alphas=None, residuals=None, agg_modes=None,
                 activations=None, biases=None, classifier_hidden_feats=128,
                 classifier_dropout=0., n_tasks=1, predictor_hidden_feats=128,
                 predictor_dropout=0.):
        super(GATPredictor, self).__init__()

        # Backward compatibility: a deprecated ``classifier_*`` argument is honoured
        # only when it was moved off its default while the ``predictor_*`` replacement
        # was left at its default. The notice goes through the warnings machinery
        # (rather than stdout) so callers can filter or escalate it.
        if predictor_hidden_feats == 128 and classifier_hidden_feats != 128:
            warnings.warn(
                'classifier_hidden_feats is deprecated and will be removed in the future, '
                'use predictor_hidden_feats instead',
                FutureWarning, stacklevel=2)
            predictor_hidden_feats = classifier_hidden_feats

        if predictor_dropout == 0. and classifier_dropout != 0.:
            warnings.warn(
                'classifier_dropout is deprecated and will be removed in the future, '
                'use predictor_dropout instead',
                FutureWarning, stacklevel=2)
            predictor_dropout = classifier_dropout

        self.gnn = GAT(in_feats=in_feats,
                       hidden_feats=hidden_feats,
                       num_heads=num_heads,
                       feat_drops=feat_drops,
                       attn_drops=attn_drops,
                       alphas=alphas,
                       residuals=residuals,
                       agg_modes=agg_modes,
                       activations=activations,
                       biases=biases)

        # The node representation size after the last GAT layer depends on how its
        # heads are aggregated: 'flatten' concatenates all heads, otherwise the
        # per-head size is kept.
        if self.gnn.agg_modes[-1] == 'flatten':
            gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
        else:
            gnn_out_feats = self.gnn.hidden_feats[-1]

        self.readout = WeightedSumAndMax(gnn_out_feats)
        # The readout concatenates the weighted-sum and max-pooling results, hence
        # the predictor input is twice the node representation size.
        self.predict = MLPPredictor(2 * gnn_out_feats, predictor_hidden_feats,
                                    n_tasks, predictor_dropout)

    def forward(self, bg, feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match
              in_feats in initialization

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            * Predictions on graphs
            * B for the number of graphs in the batch
        """
        # Update node representations, pool them per graph, then predict.
        node_feats = self.gnn(bg, feats)
        graph_feats = self.readout(bg, node_feats)
        return self.predict(graph_feats)