loss_function.py
import torch


def build_descriptor_loss(source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False):
    """
    Descriptor head loss: per-pixel triplet loss from https://arxiv.org/pdf/1902.11046.pdf.

    Parameters
    ----------
    source_des: torch.Tensor (B,256,H/8,W/8)
        Source image descriptors.
    target_des: torch.Tensor (B,256,H/8,W/8)
        Target image descriptors.
    tar_points_un: torch.Tensor (B,2,H/8,W/8)
        Target image keypoints, unnormalized.
    top_kk: torch.Tensor or None
        Boolean mask selecting which keypoints to use; if None, the loss is computed densely.
    relax_field: int
        Pixel radius within which a nearest-neighbor match counts as correct.
    eval_only: bool
        Compute only the recall, without the loss.

    Returns
    -------
    loss: torch.Tensor
        Descriptor loss.
    recall: torch.Tensor
        Descriptor match recall.
    """
    device = source_des.device
    loss = 0
    batch_size = source_des.size(0)
    recall = 0.
    # Single-scale settings (kept as lists for the pyramid loop below).
    relax_field_size = [relax_field]
    margins = [1.0]
    weights = [1.0]
    is_dense = top_kk is None
    for b_id in range(batch_size):
        if is_dense:
            ref_desc = source_des[b_id].squeeze().view(256, -1)
            tar_desc = target_des[b_id].squeeze().view(256, -1)
            tar_points_raw = tar_points_un[b_id].view(2, -1)
        else:
            top_k = top_kk[b_id].squeeze()
            n_feat = top_k.sum().item()
            if n_feat < 20:
                continue
            ref_desc = source_des[b_id].squeeze()[:, top_k]
            tar_desc = target_des[b_id].squeeze()[:, top_k]
            tar_points_raw = tar_points_un[b_id][:, top_k]

        # Compute the dense descriptor distance matrix and find nearest neighbors.
        # After L2-normalization, Euclidean distance is sqrt(2 - 2 * cosine similarity).
        ref_desc = ref_desc.div(torch.norm(ref_desc, p=2, dim=0))
        tar_desc = tar_desc.div(torch.norm(tar_desc, p=2, dim=0))
        dmat = torch.mm(ref_desc.t(), tar_desc)
        dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1))
        _, idx = torch.sort(dmat, dim=1)

        # Compute triplet loss and recall
        for pyramid in range(len(relax_field_size)):
            candidates = idx.t()
            match_k_x = tar_points_raw[0, candidates]
            match_k_y = tar_points_raw[1, candidates]
            tru_x = tar_points_raw[0]
            tru_y = tar_points_raw[1]

            if pyramid == 0:
                # Recall: fraction of keypoints whose nearest neighbor is the exact match.
                correct2 = (abs(match_k_x[0] - tru_x) == 0) & (abs(match_k_y[0] - tru_y) == 0)
                correct2_cnt = correct2.float().sum()
                recall += float(1.0 / batch_size) * (float(correct2_cnt) / float(ref_desc.size(1)))

            if eval_only:
                continue

            # Candidates within the relaxation field count as correct matches.
            correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (abs(match_k_y - tru_y) <= relax_field_size[pyramid])

            # Hardest-negative mining: pick the closest candidate outside the relaxation field.
            incorrect_index = torch.arange(start=correct_k.shape[0] - 1, end=-1, step=-1).unsqueeze(1).repeat(1, correct_k.shape[1]).to(device)
            incorrect_first = torch.argmax(incorrect_index * (1 - correct_k.long()), dim=0)
            incorrect_first_index = candidates.gather(0, incorrect_first.unsqueeze(0)).squeeze()

            anchor_var = ref_desc
            pos_var = tar_desc
            neg_var = tar_desc[:, incorrect_first_index]
            loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(anchor_var.t(), pos_var.t(), neg_var.t(), margin=margins[pyramid]).mul(weights[pyramid])

    return loss, recall
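

# A minimal smoke test for build_descriptor_loss (a sketch; the shapes and the
# regular coordinate grid are assumptions for illustration, not part of the
# original code). It runs the dense path on random descriptors, using each
# cell's own grid coordinate as the warped target keypoint.
def _demo_build_descriptor_loss():
    B, C, Hc, Wc = 2, 256, 8, 8
    source_des = torch.randn(B, C, Hc, Wc)
    target_des = torch.randn(B, C, Hc, Wc)
    # Unnormalized target coordinates: an (x, y) grid repeated per batch element.
    ys, xs = torch.meshgrid(torch.arange(Hc), torch.arange(Wc), indexing='ij')
    tar_points_un = torch.stack([xs, ys]).float().unsqueeze(0).repeat(B, 1, 1, 1)  # (B,2,Hc,Wc)
    loss, recall = build_descriptor_loss(source_des, target_des, tar_points_un, relax_field=4)
    return loss, recall
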
class KeypointLoss(object):
    """
    Loss function class encapsulating the location loss, the descriptor loss,
    the score loss, and the correspondence loss.
    """
    def __init__(self, config):
        self.score_weight = config.score_weight
        self.loc_weight = config.loc_weight
        self.desc_weight = config.desc_weight
        self.corres_weight = config.corres_weight
        self.corres_threshold = config.corres_threshold

    def __call__(self, data):
        B, _, hc, wc = data['source_score'].shape

        # Pairwise L2 distances between warped source coordinates and target coordinates.
        loc_mat_abs = torch.abs(data['target_coord_warped'].view(B, 2, -1).unsqueeze(3) - data['target_coord'].view(B, 2, -1).unsqueeze(2))
        l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1)
        l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2)

        # Construct the pseudo ground-truth matching matrix: positives are row minima
        # closer than 1 pixel, negatives are pairs farther than 4 pixels apart.
        loc_min_mat = torch.repeat_interleave(l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1)
        pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.)
        neg_mask = l2_dist_loc_mat.ge(4.)
        pos_corres = -torch.log(data['confidence_matrix'][pos_mask])
        neg_corres = -torch.log(1.0 - data['confidence_matrix'][neg_mask])
        corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean()

        # Location loss over valid correspondences: distance below the correspondence
        # threshold (4) and inside the border mask.
        dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data['border_mask'].view(B, hc * wc)
        loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean()

        # Descriptor head loss: per-pixel triplet loss from https://arxiv.org/pdf/1902.11046.pdf.
        desc_loss, _ = build_descriptor_loss(data['source_desc'], data['target_desc_warped'], data['target_coord_warped'].detach(), top_kk=data['border_mask'], relax_field=8)

        # Score loss
        target_score_associated = data['target_score'].view(B, hc * wc).gather(1, l2_dist_loc_min_index).view(B, hc, wc).unsqueeze(1)
        dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data['border_mask'].unsqueeze(1)
        l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1)
        loc_err = l2_dist_loc_min[dist_norm_valid_mask]

        # Repeatability constraint in the score loss: push scores up where the
        # localization error is below average, down where it is above.
        repeatable_constrain = ((target_score_associated[dist_norm_valid_mask] + data['source_score'][dist_norm_valid_mask]) * (loc_err - loc_err.mean())).mean()

        # Consistency constraints in the score loss: warped target maps should agree with source maps.
        consistent_constrain = torch.nn.functional.mse_loss(data['target_score_warped'][data['border_mask'].unsqueeze(1)], data['source_score'][data['border_mask'].unsqueeze(1)]).mean() * 2
        aware_consistent_loss = torch.nn.functional.mse_loss(data['target_aware_warped'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)], data['source_aware'][data['border_mask'].unsqueeze(1).repeat(1, 2, 1, 1)]).mean() * 2
        score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss

        loss = self.loc_weight * loc_loss + self.desc_weight * desc_loss + self.score_weight * score_loss + self.corres_weight * corres_loss
        return loss, self.loc_weight * loc_loss, self.desc_weight * desc_loss, self.score_weight * score_loss, self.corres_weight * corres_loss
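

# A minimal usage sketch for KeypointLoss (all shapes and config values here are
# assumptions for illustration, not taken from the original repo): build a dummy
# `data` dict with mutually consistent shapes and run the combined loss once.
def _demo_keypoint_loss():
    from types import SimpleNamespace
    B, hc, wc = 2, 8, 8
    n = hc * wc
    coords = torch.rand(B, 2, hc, wc) * 4
    data = {
        'source_score': torch.rand(B, 1, hc, wc),
        'target_score': torch.rand(B, 1, hc, wc),
        'target_score_warped': torch.rand(B, 1, hc, wc),
        'source_desc': torch.randn(B, 256, hc, wc),
        'target_desc_warped': torch.randn(B, 256, hc, wc),
        'target_coord': coords,
        # Warped coordinates: the same points plus sub-pixel noise.
        'target_coord_warped': coords + 0.5 * torch.rand(B, 2, hc, wc),
        'source_aware': torch.rand(B, 2, hc, wc),
        'target_aware_warped': torch.rand(B, 2, hc, wc),
        # Confidences clamped away from 0 and 1 so the log terms stay finite.
        'confidence_matrix': torch.rand(B, n, n).clamp(1e-6, 1 - 1e-6),
        'border_mask': torch.ones(B, hc, wc, dtype=torch.bool),
    }
    config = SimpleNamespace(score_weight=1.0, loc_weight=1.0, desc_weight=1.0,
                             corres_weight=1.0, corres_threshold=4.0)
    return KeypointLoss(config)(data)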