-
Notifications
You must be signed in to change notification settings - Fork 2
/
memory_softmax.py
52 lines (48 loc) · 1.76 KB
/
memory_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import mxnet as mx
from mxnet import autograd
from mxnet import nd
class MarginLoss(object):
    """Combined-margin softmax logits; ArcFace with the default margins.

    For the target class of each row the cosine logit ``cos(theta)`` is
    replaced by ``cos(m1 * theta + m2) - m3``; all logits are then multiplied
    by ``loss_s``. The default ``margins=(1.0, 0.5, 0.0)`` is the ArcFace
    additive angular margin. When ``m1 == 1.0`` and ``m2 == 0.0`` the CosFace
    path is taken, which subtracts ``m3`` from the target cosine directly and
    skips the arccos/cos round trip.
    """

    # Target cosines are clipped to [-1 + eps, 1 - eps] before arccos: the dot
    # product of two L2-normalized float32 vectors can round slightly outside
    # [-1, 1] (arccos -> NaN), and exactly +/-1 yields an infinite gradient.
    _COS_CLIP_EPS = 1e-7

    def __init__(self, margins=(1.0, 0.5, 0.0), loss_s=64, embedding_size=512):
        """Store the margin hyper-parameters.

        Args:
            margins: ``(m1, m2, m3)`` -- multiplicative angular margin,
                additive angular margin and additive cosine margin.
            loss_s: scale applied to every logit after the margin.
            embedding_size: stored on the instance; not used by ``forward``.
        """
        self.loss_m1 = margins[0]
        self.loss_m2 = margins[1]
        self.loss_m3 = margins[2]
        self.loss_s = loss_s
        self.embedding_size = embedding_size

    def forward(self, data, weight, mapping_label, depth):
        """Compute the scaled margin logits under ``autograd.record()``.

        Args:
            data: embeddings, L2-normalized row-wise before use.
            weight: class vectors, L2-normalized row-wise before use
                (presumably shaped (depth, embedding_size) -- TODO confirm).
            mapping_label: target column index for each row of ``data``.
            depth: number of columns of the logit matrix, used as the
                one-hot depth.

        Returns:
            ``(fc7, mapping_label_onehot)``: the margin-adjusted logits scaled
            by ``loss_s``, and the one-hot encoding of ``mapping_label``.
        """
        with autograd.record():
            norm_data = nd.L2Normalization(data)
            norm_weight = nd.L2Normalization(weight)
            # Cosine similarity of every embedding with every class vector.
            fc7 = nd.dot(norm_data, norm_weight, transpose_b=True)
            mapping_label_onehot = nd.one_hot(
                indices=mapping_label, depth=depth,
                on_value=1.0, off_value=0.0)
            if self.loss_m1 == 1.0 and self.loss_m2 == 0.0:
                # CosFace: subtract m3 from the target cosine only.
                fc7 = fc7 - mapping_label_onehot * self.loss_m3
            else:
                # Target cosine per row; all other entries are zero.
                fc7_onehot = fc7 * mapping_label_onehot
                # Keep arccos inside its domain and away from the infinite
                # derivative at +/-1 (see _COS_CLIP_EPS).
                eps = self._COS_CLIP_EPS
                cos_t = nd.clip(fc7_onehot, -1.0 + eps, 1.0 - eps)
                t = nd.arccos(cos_t)
                if self.loss_m1 != 1.0:
                    t = t * self.loss_m1
                if self.loss_m2 != 0.0:
                    t = t + self.loss_m2
                margin_cos = nd.cos(t)
                if self.loss_m3 != 0.0:
                    margin_cos = margin_cos - self.loss_m3
                # Swap the target cosine for the margin cosine; the one-hot
                # mask makes the correction zero at every other column.
                diff = margin_cos * mapping_label_onehot - fc7_onehot
                fc7 = fc7 + diff
            fc7 = fc7 * self.loss_s
            return fc7, mapping_label_onehot