-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathattention.py
122 lines (98 loc) · 4.04 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from keras.layers.core import Layer
from keras import initializers, regularizers, constraints
from keras import backend as K
"""
Attention code by Christos Baziotis https://gist.github.com/cbaziotis/6428df359af27d58078ca5ed9792bd6d,
follows work of Raffel et al. [https://arxiv.org/abs/1512.08756]
"""
def dot_product(x, kernel):
    """
    Backend-agnostic dot product between an input tensor and a weight vector.

    On the TensorFlow backend the kernel first receives a trailing axis via
    ``K.expand_dims``, is multiplied with ``K.dot``, and the trailing axis of
    the product is squeezed away again. Every other backend gets a plain
    ``K.dot(x, kernel)``.

    Args:
        x (): input
        kernel (): weights
    Returns:
        The backend tensor produced by the dot product.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)

    # todo: check that this is correct
    kernel_column = K.expand_dims(kernel)
    projected = K.dot(x, kernel_column)
    return K.squeeze(projected, axis=-1)
class Attention(Layer):
    """
    Keras Layer that implements an Attention mechanism for temporal data.
    Supports Masking.
    Follows the work of Raffel et al. [https://arxiv.org/abs/1512.08756]

    # Input shape
        3D tensor with shape: `(samples, steps, features)`.
    # Output shape
        2D tensor with shape: `(samples, features)`.
        With `return_attention=True` the layer returns a list
        `[output, attention]`, where `attention` has shape `(samples, steps)`.

    Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with
    return_sequences=True. The dimensions are inferred based on the output
    shape of the RNN. Because the bias has one entry per step, the number of
    steps must be known when the layer is built.

    Note: the original author's note says the layer was tested with Keras 1.x;
    this version uses the Keras 2 names (`initializers`,
    `compute_output_shape`).

    Example:
        model.add(LSTM(64, return_sequences=True))
        model.add(Attention())
        # next add a Dense layer (for classification/regression) or whatever...
    """

    def __init__(self,
                 kernel_regularizer=None, bias_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True,
                 return_attention=False,
                 **kwargs):
        """
        :param kernel_regularizer: regularizer (or identifier) for the weight
            vector `W`.
        :param bias_regularizer: regularizer (or identifier) for the bias `b`.
        :param W_constraint: constraint (or identifier) for `W`.
        :param b_constraint: constraint (or identifier) for `b`.
        :param bias: whether to add a per-step bias to the attention scores.
        :param return_attention: if True, `call` returns
            `[output, attention_weights]` instead of only `output`.
        :param kwargs: forwarded to the base `Layer` constructor.
        """
        # Initialise the base layer first: the Keras 2 base constructor resets
        # `supports_masking` to False, so it has to be set afterwards.
        super(Attention, self).__init__(**kwargs)
        self.supports_masking = True

        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(kernel_regularizer)
        self.b_regularizer = regularizers.get(bias_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        self.return_attention = return_attention

    def build(self, input_shape):
        """Create the weight vector (one entry per feature) and, if enabled,
        the bias vector (one entry per step)."""
        assert len(input_shape) == 3

        # `shape` is passed by keyword so the call does not depend on the
        # positional order of `add_weight`'s arguments.
        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        """Return no mask (one `None` per output): the step axis is summed
        out, so the mask is not passed on to the next layers."""
        if self.return_attention:
            return [None, None]
        return None

    def call(self, x, mask=None):
        """
        Compute the attention-weighted sum of `x` over the step axis.

        Returns the weighted sum, or `[weighted_sum, attention_weights]` when
        `return_attention` is set.
        """
        # One score per step.
        eij = dot_product(x, self.W)

        if self.bias:
            eij += self.b

        eij = K.tanh(eij)

        a = K.exp(eij)

        # apply mask after the exp. will be re-normalized next
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            a *= K.cast(mask, K.floatx())

        # Normalise over the step axis; epsilon guards against a zero sum
        # (e.g. when every step of a sample is masked).
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        # Expand only for the broadcasted product, so that the returned
        # attention weights keep the (samples, steps) shape declared by
        # compute_output_shape.
        weighted_input = x * K.expand_dims(a)

        result = K.sum(weighted_input, axis=1)

        if self.return_attention:
            return [result, a]
        return result

    def compute_output_shape(self, input_shape):
        """Return `(samples, features)`, plus `(samples, steps)` for the
        attention weights when `return_attention` is set."""
        if self.return_attention:
            return [(input_shape[0], input_shape[-1]),
                    (input_shape[0], input_shape[1])]
        else:
            return input_shape[0], input_shape[-1]

    def get_config(self):
        """Return the constructor arguments so the layer survives model
        serialization; keys match the `__init__` parameter names."""
        config = {
            'kernel_regularizer': regularizers.serialize(self.W_regularizer),
            'bias_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
            'return_attention': self.return_attention,
        }
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))