-
Notifications
You must be signed in to change notification settings - Fork 3
/
subcluster_adacos.py
65 lines (58 loc) · 2.93 KB
/
subcluster_adacos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import backend as K
class SCAdaCos(tf.keras.layers.Layer):
    """Sub-cluster AdaCos output layer with an adaptively updated scale.

    Each of the ``n_classes`` classes is represented by ``n_subclusters``
    center vectors (columns of ``W``). The layer computes cosine
    similarities between the L2-normalized input embeddings and every
    L2-normalized center, multiplies them by the scale ``s``, applies a
    softmax over all ``n_classes * n_subclusters`` logits and finally sums
    the probabilities of the sub-clusters that belong to the same class.

    The scale starts at ``sqrt(2) * log(n_classes * n_subclusters - 1)``.
    Whenever the layer is called with ``training`` truthy, it is
    re-estimated from the current batch and written back to the ``s``
    variable.

    Args:
        n_classes: Number of output classes.
        n_subclusters: Number of center vectors per class.
        regularizer: Passed through ``tf.keras.regularizers.get`` and
            attached to the center matrix ``W``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``n_classes * n_subclusters <= 1``. The initial scale
            would then require the logarithm of a non-positive number.
    """

    def __init__(self, n_classes=10, n_subclusters=1, regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes
        self.n_subclusters = n_subclusters
        if n_classes * n_subclusters <= 1:
            raise ValueError(
                'SCAdaCos needs n_classes * n_subclusters > 1 to compute the '
                'initial scale, got n_classes=%r, n_subclusters=%r'
                % (n_classes, n_subclusters))
        self.s_init = math.sqrt(2) * math.log(n_classes * n_subclusters - 1)
        self.regularizer = tf.keras.regularizers.get(regularizer)

    def build(self, input_shape):
        """Create the center matrix ``W`` and the scale variable ``s``.

        Args:
            input_shape: Shapes of ``(x, y1, y2)``. Only ``input_shape[0]``
                is used, and its last dimension gives the embedding size.
        """
        super().build(input_shape[0])
        # One column per (class, sub-cluster) pair. The reshape in call()
        # assigns columns [c * n_subclusters, (c + 1) * n_subclusters) to
        # class c.
        self.W = self.add_weight(
            name=f'W_AdaCos{self.n_classes}_{self.n_subclusters}',
            shape=(input_shape[0][-1], self.n_classes * self.n_subclusters),
            initializer='glorot_uniform',
            # The centers keep their initial values; the optimizer does not
            # update them.
            # NOTE(review): because W is non-trainable, any regularization
            # loss on it cannot be reduced by training. Confirm the
            # regularizer is intended here.
            trainable=False,
            regularizer=self.regularizer,
        )
        # Adaptive scale. It is assigned manually in call(), never by
        # gradients.
        # NOTE(review): MEAN aggregation presumably averages per-replica
        # updates under a tf.distribute strategy. Verify.
        self.s = self.add_weight(
            name=f's{self.n_classes}_{self.n_subclusters}',
            shape=(),
            initializer=tf.keras.initializers.Constant(self.s_init),
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN,
        )

    def call(self, inputs, training=None):
        """Return per-class probabilities for a batch of embeddings.

        Args:
            inputs: Tuple ``(x, y1, y2)``.
                ``x`` holds the embeddings and is L2-normalized along axis 1.
                ``y1`` holds per-class label weights. After being repeated
                ``n_subclusters`` times along its last axis, it must line up
                with the ``n_classes * n_subclusters`` logit columns. It only
                influences the scale update.
                ``y2`` is accepted but not used.
            training: When truthy, the scale ``s`` is re-estimated from this
                batch before it is applied.

        Returns:
            Tensor of shape ``(batch, n_classes)``: softmax probabilities
            summed over each class's sub-clusters.
        """
        x, y1, y2 = inputs
        y1_orig = y1
        # Repeat each class weight once per sub-cluster so that y1 lines up
        # with the logit columns (consecutive columns share a class).
        y1 = tf.repeat(y1, repeats=self.n_subclusters, axis=-1)

        # Cosine similarity between normalized embeddings and normalized
        # centers.
        x = tf.nn.l2_normalize(x, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)
        logits = x @ W  # cos(theta)
        # Clip strictly inside [-1, 1]: float rounding can push |cos| above
        # 1, which would put acos outside its domain.
        theta = tf.acos(K.clip(logits, -1.0 + K.epsilon(), 1.0 - K.epsilon()))

        if training:
            # log(B_avg) is evaluated as max + log(mean(sum(exp(s*cos - max)))).
            # Every shifted exponent is <= 0, so exp cannot overflow.
            max_s_logits = tf.reduce_max(self.s * logits)
            B_avg = tf.exp(self.s * logits - max_s_logits)
            B_avg = tf.reduce_mean(tf.reduce_sum(B_avg, axis=1))
            # Angle(s) to the labelled class(es). y1 weights every sub-cluster
            # column of a class, and the sum is multiplied by the number of
            # non-zero entries of the original label.
            # Original note: "take mix-upped angle of mix-upped classes".
            theta_class = tf.reduce_sum(y1 * theta, axis=1) * tf.math.count_nonzero(
                y1_orig, axis=1, dtype=tf.dtypes.float32)
            # Median of theta_class.
            # NOTE(review): percentile uses its default interpolation. For
            # even batch sizes this may not average the two middle values.
            # Confirm.
            theta_med = tfp.stats.percentile(theta_class, q=50)
            # s = log(B_avg) / cos(min(pi/4, theta_med)), plus epsilon.
            # Epsilon is added to the quotient, not to the denominator.
            self.s.assign(
                (max_s_logits + tf.math.log(B_avg))
                / tf.math.cos(tf.minimum(math.pi / 4, theta_med))
                + K.epsilon())

        # Apply the (possibly just updated) scale.
        logits *= self.s
        out = tf.keras.activations.softmax(logits)
        # Consecutive columns belong to the same class, so this reshape puts
        # each class's sub-clusters on axis 2 before summing them.
        out = tf.reshape(out, (-1, self.n_classes, self.n_subclusters))
        return tf.math.reduce_sum(out, axis=2)

    def compute_output_shape(self, input_shape):
        """Return ``(None, n_classes)``; the batch dimension is left unknown."""
        return (None, self.n_classes)

    def get_config(self):
        """Return the layer config with the regularizer in serialized form.

        ``__init__`` parses ``regularizer`` with ``tf.keras.regularizers.get``.
        Storing the serialized form here (instead of the live Regularizer
        object) keeps the config JSON-serializable and lets ``from_config``
        rebuild the same regularizer.
        """
        config = {
            'n_classes': self.n_classes,
            'regularizer': tf.keras.regularizers.serialize(self.regularizer),
            'n_subclusters': self.n_subclusters,
        }
        base_config = super().get_config()
        # Layer-specific keys override base keys, as in the original merge.
        return {**base_config, **config}