-
Notifications
You must be signed in to change notification settings - Fork 32
/
metrics.py
131 lines (115 loc) · 4.41 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
class AdaCos(nn.Module):
def __init__(self, num_features, num_classes, m=0.50):
super(AdaCos, self).__init__()
self.num_features = num_features
self.n_classes = num_classes
self.s = math.sqrt(2) * math.log(num_classes - 1)
self.m = m
self.W = Parameter(torch.FloatTensor(num_classes, num_features))
nn.init.xavier_uniform_(self.W)
def forward(self, input, label=None):
# normalize features
x = F.normalize(input)
# normalize weights
W = F.normalize(self.W)
# dot product
logits = F.linear(x, W)
if label is None:
return logits
# feature re-scale
theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
one_hot = torch.zeros_like(logits)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
with torch.no_grad():
B_avg = torch.where(one_hot < 1, torch.exp(self.s * logits), torch.zeros_like(logits))
B_avg = torch.sum(B_avg) / input.size(0)
# print(B_avg)
theta_med = torch.median(theta[one_hot == 1])
self.s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med), theta_med))
output = self.s * logits
return output
class ArcFace(nn.Module):
    """Additive angular margin head (ArcFace).

    Returns ``s * cos(theta + m)`` for the target class and
    ``s * cos(theta)`` for all other classes, where ``theta`` is the
    angle between the normalized feature and the normalized class weight.

    Args:
        num_features: dimensionality of the input embeddings.
        num_classes: number of target classes.
        s: fixed logit scale.
        m: additive angular margin (radians).
    """

    def __init__(self, num_features, num_classes, s=30.0, m=0.50):
        super(ArcFace, self).__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        self.W = Parameter(torch.FloatTensor(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        # cosine similarity between unit-norm features and class weights
        cosine = F.linear(F.normalize(input), F.normalize(self.W))
        if label is None:
            # inference path: plain cosine logits
            return cosine
        # recover the angle (clamped for acos stability) and add the margin
        angle = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        margined = torch.cos(angle + self.m)
        mask = torch.zeros_like(cosine)
        mask.scatter_(1, label.view(-1, 1).long(), 1)
        # use the margined logit only at the target class, then rescale
        return self.s * (cosine * (1 - mask) + margined * mask)
class SphereFace(nn.Module):
    """Multiplicative angular margin head (SphereFace).

    Returns ``s * cos(m * theta)`` for the target class and
    ``s * cos(theta)`` for all other classes.

    Args:
        num_features: dimensionality of the input embeddings.
        num_classes: number of target classes.
        s: fixed logit scale.
        m: angular margin multiplier.
    """

    def __init__(self, num_features, num_classes, s=30.0, m=1.35):
        super(SphereFace, self).__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        self.W = Parameter(torch.FloatTensor(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        # cosine similarity between unit-norm features and class weights
        cosine = F.linear(F.normalize(input), F.normalize(self.W))
        if label is None:
            # inference path: plain cosine logits
            return cosine
        # clamp keeps acos away from the +/-1 singularities
        angle = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        margined = torch.cos(self.m * angle)
        mask = torch.zeros_like(cosine)
        mask.scatter_(1, label.view(-1, 1).long(), 1)
        # use the margined logit only at the target class, then rescale
        return self.s * (cosine * (1 - mask) + margined * mask)
class CosFace(nn.Module):
    """Additive cosine margin head (CosFace / large-margin cosine loss).

    Returns ``s * (cos(theta) - m)`` for the target class and
    ``s * cos(theta)`` for all other classes — the margin is subtracted
    directly in cosine space, so no acos is needed.

    Args:
        num_features: dimensionality of the input embeddings.
        num_classes: number of target classes.
        s: fixed logit scale.
        m: cosine margin.
    """

    def __init__(self, num_features, num_classes, s=30.0, m=0.35):
        super(CosFace, self).__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        self.W = Parameter(torch.FloatTensor(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        # cosine similarity between unit-norm features and class weights
        cosine = F.linear(F.normalize(input), F.normalize(self.W))
        if label is None:
            # inference path: plain cosine logits
            return cosine
        # subtract the margin directly from the cosine logit
        margined = cosine - self.m
        mask = torch.zeros_like(cosine)
        mask.scatter_(1, label.view(-1, 1).long(), 1)
        # use the margined logit only at the target class, then rescale
        return self.s * (cosine * (1 - mask) + margined * mask)