util.py
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based transformer block layer."""
import functools
import tensorflow as tf, tf_keras


class TfFunctionIfEagerDecorator(object):
  """Helper decorator class to optionally apply the @tf.function annotation."""

  def __init__(self, **kwargs):
    self.func_kwargs = kwargs

  def __call__(self, func):

    @functools.wraps(func)
    def wrapped_func(*args):
      # TODO(b/150147476, b/150024785): Fix tf.function in TF1 crash.
      if not hasattr(
          tf.compat.v1, 'executing_eagerly_outside_functions'
      ) or tf.compat.v1.executing_eagerly_outside_functions():
        return tf.function(func=func, **self.func_kwargs)(*args)
      return func(*args)

    # Cache the created function in self._call_impl.
    if not hasattr(self, '_call_impl'):
      self._call_impl = wrapped_func
    return self._call_impl


def tf_function_if_eager(**kwargs):
  """Applies the @tf.function decorator only if running in eager mode."""
  return TfFunctionIfEagerDecorator(**kwargs)
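

# Example usage (illustrative sketch, not part of the original module): wrap a
# Keras layer method so it is compiled with tf.function when eager execution is
# available, and falls back to a plain Python call otherwise. `DoubleLayer` is
# a hypothetical layer used only for this illustration.
#
#   class DoubleLayer(tf_keras.layers.Layer):
#
#     @tf_function_if_eager()
#     def call(self, inputs):
#       return inputs * 2
#
#   layer = DoubleLayer()
#   layer(tf.ones([2, 3]))  # Under TF2 eager mode, `call` runs as a tf.function.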


def filter_kwargs(kwargs):
  """Removes unused construction options from `kwargs` in place.

  This function removes construction-signature options such as
  `num_attention_heads` in TransformerEncoderBlock. This is needed because
  Keras's base_layer.py would otherwise complain about unrecognized arguments.

  Args:
    kwargs: keyword arguments to be filtered.
  """
  # This is the union of the construction signatures of TransformerEncoderBlock
  # and ReZeroTransformer. Every transformer block whose signature is
  # compatible with TransformerEncoderBlock should call this function before
  # the base constructor super().__init__(**kwargs).
  denylist = [
      'num_attention_heads', 'intermediate_size', 'intermediate_activation',
      'inner_dim', 'inner_activation', 'output_range', 'kernel_initializer',
      'bias_initializer', 'kernel_regularizer', 'bias_regularizer',
      'activity_regularizer', 'kernel_constraint', 'bias_constraint',
      'use_bias', 'norm_first', 'norm_epsilon', 'output_dropout',
      'attention_dropout', 'inner_dropout', 'attention_initializer',
      'attention_axes', 'share_rezero'
  ]
  for unused_key in denylist:
    kwargs.pop(unused_key, None)
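

# Example usage (illustrative sketch, not part of the original module): a block
# whose constructor forwards **kwargs to the Keras base layer can first drop
# TransformerEncoderBlock-style construction options it does not consume.
# `MyTransformerBlock` is a hypothetical class used only for this illustration.
#
#   class MyTransformerBlock(tf_keras.layers.Layer):
#
#     def __init__(self, **kwargs):
#       # Strip options like 'num_attention_heads' so Keras's base_layer.py
#       # does not raise on unrecognized constructor arguments.
#       filter_kwargs(kwargs)
#       super().__init__(**kwargs)
#
#   block = MyTransformerBlock(num_attention_heads=8, name='my_block')
#   # 'num_attention_heads' is removed in place; 'name' is passed to Keras.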