-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgcn_predictor.py
106 lines (96 loc) · 4.89 KB
/
gcn_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch.nn as nn
from .mlp_predictor import MLPPredictor
from ..gnn.gcn import GCN
from ..readout.weighted_sum_and_max import WeightedSumAndMax
class GCNPredictor(nn.Module):
"""
Parameters
----------
in_feats : int
Number of input node features.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
``len(hidden_feats)`` equals the number of GCN layers. By default, we use
``[64, 64]``.
gnn_norm : list of str
``gnn_norm[i]`` gives the message passing normalizer for the i-th GCN layer, which
can be `'right'`, `'both'` or `'none'`. The `'right'` normalizer divides the aggregated
messages by each node's in-degree. The `'both'` normalizer corresponds to the symmetric
adjacency normalization in the original GCN paper. The `'none'` normalizer simply sums
the messages. ``len(gnn_norm)`` equals the number of GCN layers. By default, we use
``['none', 'none']``.
activation : list of activation functions or None
If None, no activation will be applied. If not None, ``activation[i]`` gives the
activation function to be used for the i-th GCN layer. ``len(activation)`` equals
the number of GCN layers. By default, ReLU is applied for all GCN layers.
residual : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
``len(residual)`` equals the number of GCN layers. By default, residual connection
is performed for each GCN layer.
batchnorm : list of bool
``batchnorm[i]`` decides if batch normalization is to be applied on the output of
the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
batch normalization is applied for all GCN layers.
dropout : list of float
``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
``len(dropout)`` equals the number of GCN layers. By default, no dropout is
performed for all layers.
classifier_hidden_feats : int
(Deprecated, see ``predictor_hidden_feats``) Size of hidden graph representations
in the classifier. Default to 128.
classifier_dropout : float
(Deprecated, see ``predictor_dropout``) The probability for dropout in the classifier.
Default to 0.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
predictor_hidden_feats : int
Size for hidden representations in the output MLP predictor. Default to 128.
predictor_dropout : float
The probability for dropout in the output MLP predictor. Default to 0.
"""
def __init__(self,in_feats, hidden_feats=None, gnn_norm=None, activation=None,
residual=None, batchnorm=None, dropout=None, classifier_hidden_feats=128,
classifier_dropout=0., n_tasks=1, predictor_hidden_feats=128,
predictor_dropout=0.):
super(GCNPredictor, self).__init__()
if predictor_hidden_feats == 128 and classifier_hidden_feats != 128:
print('classifier_hidden_feats is deprecated and will be removed in the future, '
'use predictor_hidden_feats instead')
predictor_hidden_feats = classifier_hidden_feats
if predictor_dropout == 0. and classifier_dropout != 0.:
print('classifier_dropout is deprecated and will be removed in the future, '
'use predictor_dropout instead')
predictor_dropout = classifier_dropout
self.gnn = GCN(in_feats=in_feats,
hidden_feats=hidden_feats,
gnn_norm=gnn_norm,
activation=activation,
residual=residual,
batchnorm=batchnorm,
dropout=dropout)
gnn_out_feats = self.gnn.hidden_feats[-1]
self.readout = WeightedSumAndMax(gnn_out_feats)
self.predict = MLPPredictor(2 * gnn_out_feats, predictor_hidden_feats,
n_tasks, predictor_dropout)
def forward(self, bg, feats,model_use):
"""Graph-level regression/soft classification.
Parameters
----------
bg : DGLGraph
DGLGraph for a batch of graphs.
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match
in_feats in initialization
Returns
-------
FloatTensor of shape (B, n_tasks)
* Predictions on graphs
* B for the number of graphs in the batch
"""
node_feats = self.gnn(bg, feats)
graph_feats = self.readout(bg, node_feats)
if model_use == 'a':
return graph_feats
else:
return self.predict(graph_feats) # GCN