-
Notifications
You must be signed in to change notification settings - Fork 9
/
fyl_pytorch.py
174 lines (120 loc) · 4.39 KB
/
fyl_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Author: Mathieu Blondel
# License: Simplified BSD
"""
PyTorch implementation of
Learning Classifiers with Fenchel-Young Losses:
Generalized Entropies, Margins, and Algorithms.
Mathieu Blondel, André F. T. Martins, Vlad Niculae.
https://arxiv.org/abs/1805.09717
"""
import torch
class ConjugateFunction(torch.autograd.Function):
    """Row-wise value of the convex conjugate Omega* at ``theta``.

    The caller supplies ``grad`` (the regularized prediction, i.e. the
    maximizer mu of <theta, mu> - Omega(mu)) so the value is simply
    ``<theta, mu> - Omega(mu)`` per row. In the backward pass, ``grad`` itself
    is used as the gradient with respect to ``theta``; no gradient is
    propagated to ``grad`` or ``Omega``.
    """

    @staticmethod
    def forward(ctx, theta, grad, Omega):
        ctx.save_for_backward(grad)
        inner_products = (theta * grad).sum(dim=1)
        return inner_products - Omega(grad)

    @staticmethod
    def backward(ctx, grad_output):
        (mu,) = ctx.saved_tensors
        # One upstream scalar per row, broadcast across the class dimension.
        row_scale = grad_output.view(-1, 1)
        return mu * row_scale, None, None
class FYLoss(torch.nn.Module):
    """Base class for Fenchel-Young losses.

    Subclasses must define ``predict(theta)`` (the regularized prediction,
    used as the gradient of Omega* with respect to ``theta``) and
    ``Omega(mu)`` (the regularizer, evaluated row-wise along dim 1).
    The per-sample loss is::

        Omega*(theta) + Omega(y) - <theta, y>

    Parameters
    ----------
    weights : str
        ``"average"`` (default) returns the mean of the per-sample losses;
        any other value returns their sum.
    """

    def __init__(self, weights="average"):
        # nn.Module.__init__ must run before any attribute is assigned.
        super(FYLoss, self).__init__()
        self.weights = weights

    def forward(self, theta, y_true):
        """Compute the Fenchel-Young loss.

        Parameters
        ----------
        theta : torch.Tensor
            Scores indexed as ``theta[sample, class]``.
        y_true : torch.Tensor
            Either a 2-D tensor of label proportions (same layout as
            ``theta``), or a 1-D ``torch.long`` tensor of class indices
            (0, ..., n_classes-1), one per row of ``theta``.

        Returns
        -------
        torch.Tensor
            Scalar: the mean (``weights="average"``) or the sum of the
            per-sample losses.

        Raises
        ------
        ValueError
            If ``y_true`` is 1-D but not ``torch.long``, or is neither 1-D
            nor 2-D.
        """
        # Kept as an attribute so callers can read the predictions afterwards.
        self.y_pred = self.predict(theta)
        ret = ConjugateFunction.apply(theta, self.y_pred, self.Omega)

        if y_true.dim() == 1:
            # y_true contains label integers (0, ..., n_classes-1)
            if y_true.dtype != torch.long:
                raise ValueError("y_true should contains long integers.")
            # Expand to one-hot so Omega(y) is included exactly as in the
            # 2-D case. Omega(e_y) is 0 for most regularizers but not all
            # (e.g. 0.5 * ||e_y||^2 = 0.5 for the squared loss), so
            # skipping it made integer and one-hot labels disagree.
            all_rows = torch.arange(theta.shape[0], device=theta.device)
            y_onehot = torch.zeros_like(theta)
            y_onehot[all_rows, y_true] = 1
            y_true = y_onehot
        elif y_true.dim() != 2:
            raise ValueError("Invalid shape for y_true.")

        # y_true now contains label proportions (or one-hot rows).
        ret = ret + self.Omega(y_true) - torch.sum(y_true * theta, dim=1)

        if self.weights == "average":
            return torch.mean(ret)
        return torch.sum(ret)
class SquaredLoss(FYLoss):
    """Squared loss: the Fenchel-Young loss generated by 0.5 * ||mu||^2.

    The prediction is the identity map on the scores.
    """

    def predict(self, theta):
        return theta

    def Omega(self, mu):
        # Half the squared Euclidean norm of each row.
        return 0.5 * mu.pow(2).sum(dim=1)
class PerceptronLoss(FYLoss):
    """Perceptron loss: the Fenchel-Young loss with a zero regularizer.

    Predictions are one-hot rows marking the arg-max score of each row.
    """

    def predict(self, theta):
        winners = torch.argmax(theta, dim=1, keepdim=True)
        return torch.zeros_like(theta).scatter_(1, winners, 1)

    def Omega(self, theta):
        # The regularizer is identically zero.
        return 0
def Shannon_negentropy(p, dim):
    """Sum of ``p * log(p)`` along ``dim``, counting non-positive entries as 0.

    Only strictly positive entries are fed to ``log``, so zeros contribute
    nothing to the value and receive a zero gradient (no NaN/inf).
    """
    positive = p > 0
    terms = torch.zeros_like(p)
    p_pos = p[positive]
    terms[positive] = p_pos * p_pos.log()
    return terms.sum(dim)
class LogisticLoss(FYLoss):
    """Logistic loss: the Fenchel-Young loss generated by Shannon negentropy.

    Predictions are the row-wise softmax of the scores.
    """

    def Omega(self, p):
        # Row-wise sum of p * log(p), with zero entries contributing 0.
        return Shannon_negentropy(p, dim=1)

    def predict(self, theta):
        activation = torch.nn.Softmax(dim=1)
        return activation(theta)
class Logistic_OVA_Loss(FYLoss):
    """One-vs-all logistic loss: an independent sigmoid for every class."""

    def predict(self, theta):
        activation = torch.nn.Sigmoid()
        return activation(theta)

    def Omega(self, p):
        # Binary negentropy of each coordinate (p and 1 - p), summed per row.
        positive_part = Shannon_negentropy(p, dim=1)
        negative_part = Shannon_negentropy(1 - p, dim=1)
        return positive_part + negative_part
# begin: From OpenNMT-py
def threshold_and_support(z, dim=0):
    """Compute the sparsemax threshold and support size along ``dim``.

    z: any dimension
    dim: dimension along which to apply the sparsemax

    Returns ``(tau_z, k_z)``, both shaped like ``z`` except with size 1 along
    ``dim``: ``tau_z`` is the threshold to subtract from ``z`` and ``k_z`` is
    the number of entries kept, cast to ``z.dtype``.
    """
    z_sorted = torch.sort(z, dim=dim, descending=True)[0]
    # Cumulative sums minus 1: numerators of the candidate thresholds.
    csum = z_sorted.cumsum(dim) - 1
    # Ranks 1..n laid out along ``dim`` so they broadcast against z_sorted.
    bshape = [1] * z.dim()
    bshape[dim] = -1
    ranks = torch.arange(1, z_sorted.size(dim) + 1, device=z.device)
    ranks = ranks.type(z.dtype).view(bshape)
    in_support = ranks * z_sorted > csum
    support_size = in_support.sum(dim=dim, keepdim=True)
    support_size_f = support_size.type(z.dtype)
    # The threshold uses the cumulative sum at the last supported rank.
    tau = csum.gather(dim, support_size - 1) / support_size_f
    return tau, support_size_f
class SparsemaxFunction(torch.autograd.Function):
    """Sparsemax along ``dim`` with a closed-form backward pass.

    The forward output is ``max(input - tau, 0)``, where the threshold ``tau``
    comes from ``threshold_and_support``.
    """

    @staticmethod
    def forward(ctx, input, dim=0):
        """
        input (FloatTensor): any shape
        returns (FloatTensor): same shape with sparsemax computed on given dim
        """
        ctx.dim = dim
        tau_z, k_z = threshold_and_support(input, dim=dim)
        output = torch.clamp(input - tau_z, min=0)
        ctx.save_for_backward(k_z, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Vector-Jacobian product of sparsemax.

        Where ``output != 0``, the result is ``grad_output`` minus the sum of
        ``grad_output`` over those entries divided by the support size ``k_z``.
        Elsewhere it is zero. No gradient is returned for ``dim``.
        """
        k_z, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0
        # k_z already has a size-1 axis at ``dim``; keepdim=True aligns the
        # sum with it exactly. Squeezing k_z without a dim would also drop
        # any other size-1 axes and broadcast v_hat to the wrong shape.
        v_hat = grad_input.sum(dim=dim, keepdim=True) / k_z
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None
# Functional form of SparsemaxFunction: sparsemax(input, dim) returns a tensor
# of the same shape with sparsemax computed along ``dim``, differentiable via
# SparsemaxFunction.backward. Pass ``dim`` positionally, as Sparsemax.forward does.
sparsemax = SparsemaxFunction.apply
class Sparsemax(torch.nn.Module):
    """Module wrapper applying ``sparsemax`` along a fixed dimension.

    Parameters
    ----------
    dim : int
        Dimension along which sparsemax is computed (default 0).
    """

    def __init__(self, dim=0):
        # nn.Module.__init__ must run before any attribute is assigned.
        super(Sparsemax, self).__init__()
        self.dim = dim

    def forward(self, input):
        """Return ``sparsemax(input, self.dim)`` (same shape as ``input``)."""
        return sparsemax(input, self.dim)
# end: From OpenNMT-py
class SparsemaxLoss(FYLoss):
    """Sparsemax loss: the Fenchel-Young loss generated by 0.5*||p||^2 - 0.5.

    Predictions are the row-wise sparsemax of the scores (non-negative,
    typically sparse rows).
    """

    def Omega(self, p):
        # The -0.5 offset makes Omega vanish on one-hot rows (0.5 * 1 - 0.5).
        squared_norm = p.pow(2).sum(dim=1)
        return 0.5 * squared_norm - 0.5

    def predict(self, theta):
        projector = Sparsemax(dim=1)
        return projector(theta)