forked from zyyll/dual_encoding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
91 lines (74 loc) · 2.79 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from torch.autograd import Variable
import torch.nn as nn
def cosine_sim(im, s):
    """Score every image against every sentence with a dot product.

    The function itself performs no normalisation: the result equals the
    cosine similarity only when the rows of ``im`` and ``s`` have already
    been L2-normalised by the caller.

    Args:
        im: 2-D tensor, one image embedding per row.
        s: 2-D tensor, one sentence embedding per row (same width as ``im``).

    Returns:
        Tensor of shape ``(im.size(0), s.size(0))`` whose entry ``[i, j]`` is
        the dot product of image ``i`` and sentence ``j``.
    """
    sentences_as_columns = s.t()
    return torch.mm(im, sentences_as_columns)
def order_sim(im, s):
    """Order-embedding similarity between every image and every sentence.

    For each (image, sentence) pair the penalty is the L2 norm of the
    positive part of ``s - im``; the similarity is that penalty negated, so
    0 is the best possible score.

    Args:
        im: 2-D tensor, one image embedding per row.
        s: 2-D tensor, one sentence embedding per row.

    Returns:
        Tensor of shape ``(im.size(0), s.size(0))`` with entry ``[i, j]``
        equal to ``-||max(0, s[j] - im[i])||_2``.
    """
    n_sent, n_img, width = s.size(0), im.size(0), s.size(1)
    grid_shape = (n_sent, n_img, width)

    # Lay both inputs out on a common (sentence, image, feature) grid.
    sent_grid = s.unsqueeze(1).expand(*grid_shape)
    img_grid = im.unsqueeze(0).expand(*grid_shape)

    violation = torch.clamp(sent_grid - img_grid, min=0)
    penalty = violation.pow(2).sum(2).sqrt()

    # Transpose so rows index images and columns index sentences.
    return -penalty.t()
def euclidean_sim(im, s):
    """Negative squared Euclidean distance between images and sentences.

    Args:
        im: 2-D tensor, one image embedding per row.
        s: 2-D tensor, one sentence embedding per row.

    Returns:
        Tensor of shape ``(im.size(0), s.size(0))`` with entry ``[i, j]``
        equal to ``-||s[j] - im[i]||_2 ** 2`` (no square root is taken), so
        identical embeddings score 0 and everything else is negative.
    """
    n_sent, n_img, width = s.size(0), im.size(0), s.size(1)
    grid_shape = (n_sent, n_img, width)

    # Lay both inputs out on a common (sentence, image, feature) grid.
    sent_grid = s.unsqueeze(1).expand(*grid_shape)
    img_grid = im.unsqueeze(0).expand(*grid_shape)

    squared_distance = (sent_grid - img_grid).pow(2).sum(2)

    # Transpose so rows index images and columns index sentences.
    return -squared_distance.t()
class TripletLoss(nn.Module):
    """Margin-based triplet ranking loss over a batch of matched pairs.

    Row ``i`` of the image batch and row ``i`` of the sentence batch are
    treated as the positive pair (the diagonal of the score matrix); every
    other entry of the batch serves as a negative.

    Args:
        margin: Amount by which a positive score must exceed a negative one
            before the pair stops contributing to the loss.
        measure: ``'order'`` selects ``order_sim``, ``'euclidean'`` selects
            ``euclidean_sim``; any other value (including the default
            ``False``) selects ``cosine_sim``.
        max_violation: When true, keep only the hardest (largest-cost)
            negative per query instead of all negatives.
        cost_style: ``'sum'`` adds the costs up; any other value averages
            them.
        direction: ``'i2t'`` ranks sentences for each image, ``'t2i'`` ranks
            images for each sentence, ``'all'`` does both. Any other value
            disables both terms and the loss is zero.
    """

    def __init__(self, margin=0, measure=False, max_violation=False, cost_style='sum', direction='all'):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.cost_style = cost_style
        self.direction = direction
        if measure == 'order':
            self.sim = order_sim
        elif measure == 'euclidean':
            self.sim = euclidean_sim
        else:
            self.sim = cosine_sim
        self.max_violation = max_violation

    def forward(self, s, im):
        """Compute the ranking loss for one batch.

        Args:
            s: Sentence embeddings, one per row.
            im: Image embeddings, one per row; must have as many rows as
                ``s`` because the score matrix is assumed to be square.

        Returns:
            A tensor holding the loss: the sum (or mean) of the hinge costs
            of the enabled retrieval directions.
        """
        # scores[i, j] is the similarity of image i and sentence j.
        scores = self.sim(im, s)
        diagonal = scores.diag().view(im.size(0), 1)
        d1 = diagonal.expand_as(scores)      # d1[i, j] = scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] = scores[j, j]

        # Mask selecting the positive pairs. It is created on the same device
        # as the scores so the loss works for CPU and GPU tensors alike,
        # independently of whether CUDA happens to be installed.
        positives = torch.eye(scores.size(0), device=scores.device) > .5

        cost_s = None
        cost_im = None
        if self.direction in ('i2t', 'all'):
            # Caption retrieval: within row i, compare every sentence with
            # the positive sentence of image i.
            cost_s = (self.margin + scores - d1).clamp(min=0)
            cost_s = cost_s.masked_fill(positives, 0)
        if self.direction in ('t2i', 'all'):
            # Image retrieval: within column j, compare every image with the
            # positive image of sentence j.
            cost_im = (self.margin + scores - d2).clamp(min=0)
            cost_im = cost_im.masked_fill(positives, 0)

        # Keep only the maximum violating negative for each query.
        if self.max_violation:
            if cost_s is not None:
                cost_s = cost_s.max(1)[0]
            if cost_im is not None:
                cost_im = cost_im.max(0)[0]

        # A disabled direction contributes a zero that matches the scores'
        # device and dtype (the previous unconditional .cuda() call failed on
        # CPU-only machines).
        if cost_s is None:
            cost_s = scores.new_zeros(1)
        if cost_im is None:
            cost_im = scores.new_zeros(1)

        if self.cost_style == 'sum':
            return cost_s.sum() + cost_im.sum()
        return cost_s.mean() + cost_im.mean()