#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
"""
Some functions are from AllSet:
https://arxiv.org/abs/2106.13264
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import softmax
from torch_scatter import scatter


def get_activation(act, inplace=False):
    """
    Parameters
    ----------
    act
        Name of the activation
    inplace
        Whether to perform inplace activation

    Returns
    -------
    activation_layer
        The activation
    """
    if act is None:
        return lambda x: x
    if isinstance(act, str):
        if act == 'leaky':
            # TODO(sxjscience) Add regex matching here to parse `leaky(0.1)`
            return nn.LeakyReLU(0.1, inplace=inplace)
        if act == 'identity':
            return nn.Identity()
        if act == 'elu':
            return nn.ELU(inplace=inplace)
        if act == 'gelu':
            return nn.GELU()
        if act == 'relu':
            return nn.ReLU()
        if act == 'sigmoid':
            return nn.Sigmoid()
        if act == 'tanh':
            return nn.Tanh()
        if act in {'softrelu', 'softplus'}:
            return nn.Softplus()
        if act == 'softsign':
            return nn.Softsign()
        raise NotImplementedError('act="{}" is not supported. '
                                  'Try to include it if you can find that in '
                                  'https://pytorch.org/docs/stable/nn.html'.format(act))
    return act
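

# Illustrative usage of `get_activation` (a sketch, not part of the original
# module): the activation name comes straight from a config value.
#
#     act_fn = get_activation('gelu')          # -> nn.GELU()
#     y = act_fn(torch.randn(2, 8))            # same shape as the input, (2, 8)
#     get_activation(None)(y)                  # None yields an identity lambda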


class PositionwiseFFN(nn.Module):
    """The Position-wise FFN layer used in Transformer-like architectures.

    If pre_norm is True:
        norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
    Else:
        data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(res(+data))

    If a gated projection is used, the input is mapped with
        act(ffn_1_gate(data)) * ffn_1(data)
    instead of act(ffn_1(data)).
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config
            Configuration object providing the fields read below:
            hidden_size
            intermediate_size
            hidden_act
            hidden_dropout_prob
            activation_dropout
            gated_proj
            layer_norm_eps
            pre_norm
                Pre-layer normalization as proposed in the paper:
                "[ACL2018] The Best of Both Worlds: Combining Recent Advances in
                Neural Machine Translation"
                This will stabilize the training of Transformers.
                You may also refer to
                "[Arxiv2020] Understanding the Difficulty of Training Transformers"
        """
        super().__init__()
        self.config = config
        self.dropout_layer = nn.Dropout(self.config.hidden_dropout_prob)
        self.activation_dropout_layer = nn.Dropout(self.config.activation_dropout)
        self.ffn_1 = nn.Linear(in_features=self.config.hidden_size,
                               out_features=self.config.intermediate_size,
                               bias=True)
        if self.config.gated_proj:
            # The gate is multiplied elementwise with ffn_1's output in forward(),
            # so it must project to intermediate_size as well.
            self.ffn_1_gate = nn.Linear(in_features=self.config.hidden_size,
                                        out_features=self.config.intermediate_size,
                                        bias=True)
        self.activation = get_activation(self.config.hidden_act)
        self.ffn_2 = nn.Linear(in_features=self.config.intermediate_size,
                               out_features=self.config.hidden_size,
                               bias=True)
        self.layer_norm = nn.LayerNorm(eps=self.config.layer_norm_eps,
                                       normalized_shape=self.config.hidden_size)
        self.init_weights()

    def init_weights(self):
        for module in self.children():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, data):
        """
        Parameters
        ----------
        data :
            Shape (B, seq_length, C_in)

        Returns
        -------
        out :
            Shape (B, seq_length, C_out)
        """
        residual = data
        if self.config.pre_norm:
            data = self.layer_norm(data)
        if self.config.gated_proj:
            out = self.activation(self.ffn_1_gate(data)) * self.ffn_1(data)
        else:
            out = self.activation(self.ffn_1(data))
        out = self.activation_dropout_layer(out)
        out = self.ffn_2(out)
        out = self.dropout_layer(out)
        out = out + residual
        if not self.config.pre_norm:
            out = self.layer_norm(out)
        return out
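

# Illustrative usage of PositionwiseFFN (a sketch, not part of the original
# module). The config object and the field values below are assumptions; any
# object exposing the fields read in __init__ works, e.g. a SimpleNamespace:
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(hidden_size=64, intermediate_size=256,
#                           hidden_act='gelu', hidden_dropout_prob=0.1,
#                           activation_dropout=0.1, gated_proj=False,
#                           layer_norm_eps=1e-12, pre_norm=True)
#     ffn = PositionwiseFFN(cfg)
#     out = ffn(torch.randn(2, 5, 64))    # -> shape (2, 5, 64)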


# Method for initialization
def glorot(tensor):
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


class AllSetTrans(MessagePassing):
    """
    AllSetTrans part:
    Note that in the original PMA, we need to compute the inner product of the
    seed and the neighbor nodes, i.e. e_ij = a(Wh_i, Wh_j), where a should be the
    inner product, h_i is the seed and h_j are the neighbor nodes.
    In GAT, a(x, y) = a^T[x || y]. We use the same logic here.
    """

    def __init__(self, config, negative_slope=0.2, **kwargs):
        super(AllSetTrans, self).__init__(node_dim=0, **kwargs)

        self.in_channels = config.hidden_size
        self.heads = config.num_attention_heads
        self.hidden = config.hidden_size // self.heads
        self.out_channels = config.hidden_size

        self.negative_slope = negative_slope
        self.dropout = config.attention_probs_dropout_prob
        self.aggr = 'add'

        self.lin_K = Linear(self.in_channels, self.heads * self.hidden)
        self.lin_V = Linear(self.in_channels, self.heads * self.hidden)
        self.att_r = Parameter(torch.Tensor(1, self.heads, self.hidden))  # Seed vector
        self.rFF = PositionwiseFFN(config)

        self.ln0 = nn.LayerNorm(self.heads * self.hidden)
        self.ln1 = nn.LayerNorm(self.heads * self.hidden)

        self._alpha = None
        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin_K.weight)
        glorot(self.lin_V.weight)
        self.ln0.reset_parameters()
        self.ln1.reset_parameters()
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index: Adj, return_attention_weights=None):
        """
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.hidden
        alpha_r: OptTensor = None

        assert x.dim() == 2, 'Static graphs not supported in `AllSetTrans`.'
        x_K = self.lin_K(x).view(-1, H, C)
        x_V = self.lin_V(x).view(-1, H, C)
        alpha_r = (x_K * self.att_r).sum(dim=-1)

        out = self.propagate(edge_index, x=x_V,
                             alpha=alpha_r, aggr=self.aggr)

        alpha = self._alpha
        self._alpha = None
        out += self.att_r  # Seed + Multihead
        # Concatenate heads, then LayerNorm.
        out = self.ln0(out.view(-1, self.heads * self.hidden))
        # rFF and skip connection.
        out = self.ln1(out + F.relu(self.rFF(out)))

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j, alpha_j, index, ptr):
        alpha = alpha_j
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalize the attention scores over the messages that share a target index.
        alpha = softmax(alpha, index, ptr, num_nodes=int(index.max()) + 1)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def aggregate(self, inputs, index, aggr=None):
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if aggr is None:
            aggr = self.aggr
        return scatter(inputs, index, dim=self.node_dim, reduce=aggr)
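
    # A minimal illustration of the delegated scatter call (toy values assumed,
    # not from the original code):
    #
    #     inputs = torch.tensor([[1.], [2.], [3.]])
    #     index = torch.tensor([0, 0, 1])
    #     scatter(inputs, index, dim=0, reduce='add')   # -> [[3.], [3.]]
    #
    # i.e. messages that share a target index are summed into a single row.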

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
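

if __name__ == '__main__':
    # A minimal smoke test (illustrative sketch only, not part of the original
    # file). The config object is an assumption: any object exposing the fields
    # read by AllSetTrans and PositionwiseFFN works; SimpleNamespace is used here.
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=8, num_attention_heads=2,
                          attention_probs_dropout_prob=0.0,
                          intermediate_size=32, hidden_act='gelu',
                          hidden_dropout_prob=0.0, activation_dropout=0.0,
                          gated_proj=False, layer_norm_eps=1e-12, pre_norm=True)

    layer = AllSetTrans(cfg)
    x = torch.randn(4, cfg.hidden_size)          # 4 node embeddings
    # Toy incidence structure: row 0 indexes source nodes, row 1 indexes the
    # two target hyperedges that aggregate them.
    edge_index = torch.tensor([[0, 1, 2, 1, 2, 3],
                               [0, 0, 0, 1, 1, 1]])
    out = layer(x, edge_index)
    print(out.shape)                             # -> torch.Size([2, 8])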