-
Notifications
You must be signed in to change notification settings - Fork 0
/
transformer.py
218 lines (204 loc) · 8.09 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import tensorflow as tf
def _check_masks_shapes(inputs, padding_mask, attention_mask):
mask = padding_mask
if hasattr(inputs, "_keras_mask") and mask is None:
mask = inputs._keras_mask
if mask is not None:
if len(mask.shape) != 2:
raise ValueError(
"`padding_mask` should have shape "
"(batch_size, target_length). "
f"Received shape `{mask.shape}`."
)
if attention_mask is not None:
if len(attention_mask.shape) != 3:
raise ValueError(
"`attention_mask` should have shape "
"(batch_size, target_length, source_length). "
f"Received shape `{mask.shape}`."
)
def merge_padding_and_attention_mask(
    inputs,
    padding_mask,
    attention_mask,
):
    """Merge a padding mask and an attention mask into one attention mask.

    The padding mask is taken from `padding_mask` or, when that is None,
    from the `_keras_mask` attribute of `inputs` if present. It is expanded
    at axis 1 so it can broadcast against the rank-3 attention mask, and
    both masks are cast to int32 before being combined with an element-wise
    minimum.

    Args:
        inputs: The layer inputs. Only the optional `_keras_mask` attribute
            is read.
        padding_mask: Optional rank-2 mask. Takes precedence over the mask
            attached to `inputs`.
        attention_mask: Optional rank-3 mask.

    Returns:
        The merged int32 mask. If only one mask is available, that mask
        (cast to int32) is returned; if neither is available, None.

    Raises:
        ValueError: If a mask has the wrong rank (raised by
            `_check_masks_shapes`).
    """
    _check_masks_shapes(inputs, padding_mask, attention_mask)
    mask = padding_mask
    if hasattr(inputs, "_keras_mask"):
        if mask is None:
            # If no padding mask is explicitly provided, we look for padding
            # mask from the input data.
            mask = inputs._keras_mask
        else:
            # This module only imports TensorFlow, so use its logger.
            tf.get_logger().warning(
                "You are explicitly setting `padding_mask` while the `inputs` "
                "have built-in mask, so the built-in mask is ignored."
            )
    if mask is not None:
        # Add an axis for broadcasting, the attention mask should be 2D
        # (not including the batch axis).
        mask = tf.cast(tf.expand_dims(mask, axis=1), "int32")
    if attention_mask is None:
        return mask
    attention_mask = tf.cast(attention_mask, "int32")
    if mask is None:
        return attention_mask
    # A position is kept only when both masks allow it.
    return tf.minimum(mask, attention_mask)
def clone_initializer(initializer):
    """Return a fresh copy of a Keras initializer instance.

    Args:
        initializer: An initializer instance, or any other identifier
            (for example a string or dict).

    Returns:
        A new instance rebuilt from the initializer's config when
        `initializer` is a `tf.keras.initializers.Initializer`; otherwise
        the argument itself, unchanged.
    """
    if isinstance(initializer, tf.keras.initializers.Initializer):
        # Round-trip through the config to obtain an independent instance.
        initializer_cls = initializer.__class__
        return initializer_cls.from_config(initializer.get_config())
    # Strings and dicts cannot (and should not) be cloned.
    return initializer
class TransformerEncoder(tf.keras.layers.Layer):
    """Transformer encoder block.

    The block applies multi-head self-attention followed by a two-layer
    feedforward network. Each of the two sub-blocks has dropout, a residual
    connection and layer normalization.

    Args:
        intermediate_dim: Number of units of the first feedforward dense
            layer.
        num_heads: Number of heads of the self-attention layer.
        dropout: Dropout rate used by the attention layer and by the two
            dropout layers.
        activation: Activation identifier for the first feedforward dense
            layer.
        layer_norm_epsilon: Epsilon of both layer normalization layers.
        kernel_initializer: Kernel initializer identifier for the attention
            and dense layers.
        bias_initializer: Bias initializer identifier for the attention and
            dense layers.
        normalize_first: If True, layer normalization is applied to the
            input of each sub-block; otherwise it is applied after the
            residual addition.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(
        self,
        intermediate_dim,
        num_heads,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-05,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        normalize_first=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Plain hyperparameters are stored as given.
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.normalize_first = normalize_first
        # Identifiers are resolved through the Keras registries.
        self.activation = tf.keras.activations.get(activation)
        initializers = tf.keras.initializers
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.supports_masking = True

    def build(self, inputs_shape):
        """Create and build every sublayer for the given input shape.

        Args:
            inputs_shape: Shape of the inputs; its last entry is used as the
                hidden feature size.

        Raises:
            ValueError: If the hidden size is smaller than `num_heads`, which
                would give an attention head size of zero.
        """
        # The hidden feature size comes from the last axis of the inputs and
        # is split evenly across the attention heads.
        hidden_dim = inputs_shape[-1]
        head_dim = int(hidden_dim // self.num_heads)
        if head_dim == 0:
            raise ValueError(
                "Attention `key_dim` computed cannot be zero. "
                f"The `hidden_dim` value of {hidden_dim} has to be equal to "
                f"or greater than `num_heads` value of {self.num_heads}."
            )

        layers = tf.keras.layers
        policy = self.dtype_policy

        def fresh_initializers():
            # Every sublayer receives its own initializer copies.
            return {
                "kernel_initializer": clone_initializer(
                    self.kernel_initializer
                ),
                "bias_initializer": clone_initializer(self.bias_initializer),
            }

        def make_layer_norm(name):
            return layers.LayerNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=policy,
                name=name,
            )

        def make_dropout(name):
            return layers.Dropout(rate=self.dropout, dtype=policy, name=name)

        # --- Self-attention sub-block -------------------------------------
        attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=head_dim,
            dropout=self.dropout,
            dtype=policy,
            name="self_attention_layer",
            **fresh_initializers(),
        )
        self._self_attention_layer = attention
        # The attention layer is built through `_build_from_signature` when
        # that method exists, and through `build` otherwise.
        if not hasattr(attention, "_build_from_signature"):
            attention.build(
                query_shape=inputs_shape,
                value_shape=inputs_shape,
            )
        else:
            attention._build_from_signature(
                query=inputs_shape,
                value=inputs_shape,
            )

        attention_norm = make_layer_norm("self_attention_layer_norm")
        self._self_attention_layer_norm = attention_norm
        attention_norm.build(inputs_shape)

        self._self_attention_dropout = make_dropout("self_attention_dropout")

        # --- Feedforward sub-block ----------------------------------------
        feedforward_norm = make_layer_norm("feedforward_layer_norm")
        self._feedforward_layer_norm = feedforward_norm
        feedforward_norm.build(inputs_shape)

        expand_dense = layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            dtype=policy,
            name="feedforward_intermediate_dense",
            **fresh_initializers(),
        )
        self._feedforward_intermediate_dense = expand_dense
        expand_dense.build(inputs_shape)

        project_dense = layers.Dense(
            hidden_dim,
            dtype=policy,
            name="feedforward_output_dense",
            **fresh_initializers(),
        )
        self._feedforward_output_dense = project_dense
        # The output projection consumes the intermediate activations, so
        # its input shape has `intermediate_dim` on the last axis.
        expanded_shape = tuple(
            list(inputs_shape)[:-1] + [self.intermediate_dim]
        )
        project_dense.build(expanded_shape)

        self._feedforward_dropout = make_dropout("feedforward_dropout")
        self.built = True

    def call(
        self, inputs, padding_mask=None, attention_mask=None, training=None
    ):
        """Run the encoder block on `inputs`.

        Args:
            inputs: Tensor used as both query and value of the
                self-attention layer.
            padding_mask: Optional padding mask, merged with
                `attention_mask` by `merge_padding_and_attention_mask`.
            attention_mask: Optional attention mask, merged with
                `padding_mask` by `merge_padding_and_attention_mask`.
            training: Forwarded to the attention and dropout layers.

        Returns:
            The output of the feedforward sub-block.
        """
        pre_norm = self.normalize_first
        merged_mask = merge_padding_and_attention_mask(
            inputs, padding_mask, attention_mask
        )

        # Self-attention sub-block with residual connection.
        attended = (
            self._self_attention_layer_norm(inputs) if pre_norm else inputs
        )
        attended = self._self_attention_layer(
            query=attended,
            value=attended,
            attention_mask=merged_mask,
            training=training,
        )
        attended = self._self_attention_dropout(attended, training=training)
        hidden = attended + inputs
        if not pre_norm:
            hidden = self._self_attention_layer_norm(hidden)

        # Feedforward sub-block with residual connection.
        projected = (
            self._feedforward_layer_norm(hidden) if pre_norm else hidden
        )
        projected = self._feedforward_intermediate_dense(projected)
        projected = self._feedforward_output_dense(projected)
        projected = self._feedforward_dropout(projected, training=training)
        outputs = projected + hidden
        if not pre_norm:
            outputs = self._feedforward_layer_norm(outputs)
        return outputs

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        serialize_initializer = tf.keras.initializers.serialize
        config["intermediate_dim"] = self.intermediate_dim
        config["num_heads"] = self.num_heads
        config["dropout"] = self.dropout
        config["activation"] = tf.keras.activations.serialize(self.activation)
        config["layer_norm_epsilon"] = self.layer_norm_epsilon
        config["kernel_initializer"] = serialize_initializer(
            self.kernel_initializer
        )
        config["bias_initializer"] = serialize_initializer(
            self.bias_initializer
        )
        config["normalize_first"] = self.normalize_first
        return config

    def compute_output_shape(self, inputs_shape):
        """Return `inputs_shape` unchanged as the output shape."""
        return inputs_shape