from __future__ import print_function
import torch
import numpy as np
"""
Base class for sampling.
"""
class TripletSampler(object):
def __init__(self, num_classes, num_samples):
self.num_classes = int(num_classes)
self.num_samples = int(num_samples)
"""
Reset all samples, that is, clear all mined hard examples.
"""
def Reset(self):
pass
"""
Call this after every batch to generate a list of hard negatives.
"""
def SampleNegatives(self, dista, distb, triplet_loss, ids):
print("Implement me!!")
pass
"""
Call this when regenerating list of triplets, to get a set of negative pairs.
"""
# TODO: may want to update the signature.
def ChooseNegatives(self, num):
print("Implement me!!")
pass
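# Illustrative sketch (not part of the original file): one way a TripletSampler
# subclass might be driven from a triplet training loop. `triplet_net`,
# `train_loader`, and `regenerate_triplets` are hypothetical placeholders,
# assumed to yield (dista, distb, loss) per batch and to rebuild the triplet
# list from mined (anchor, negative) pairs.
def _example_training_loop(sampler, triplet_net, train_loader, regenerate_triplets, epochs=1):
    for epoch in range(epochs):
        sampler.Reset()  # clear hard examples mined in the previous epoch
        for (anchor, pos, neg), ids in train_loader:
            # dista: anchor-positive distances, distb: anchor-negative distances
            dista, distb, loss = triplet_net(anchor, pos, neg)
            # mine hard negatives from this batch
            sampler.SampleNegatives(dista, distb, loss, ids)
        # use the mined (anchor, negative) pairs to rebuild the triplet list
        regenerate_triplets(sampler.ChooseNegatives(sampler.num_samples))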
"""
Get N hardest:
Every time SampleNegatives is invoked, this chooses the
triplets with N hardest negatives from a set of already constructred triplets,
such as the triplets used in training/
When ChooseNegatives is invoked, this returns a set of hard negatives from
triplets that it has seen before.
"""
class NHardestTripletSampler(TripletSampler):
def __init__(self, num_classes, num_samples):
super(NHardestTripletSampler, self).__init__(num_classes, num_samples)
self.negatives = []
self.dist_neg = []
def Reset(self):
self.negatives = []
self.dist_neg = []
"""
Negatives with least distb.
"""
def SampleNegatives(self, dista, distb, triplet_loss, ids):
distb = distb.cpu()
assert(self.num_samples <= dista.size()[0])
idx1, idx2, idx3 = ids
# sort by distance between anchor and negative
sortd, indices = torch.sort(distb, descending=False, dim=0)
sel_indices = indices[0:self.num_samples].data.numpy().reshape((self.num_samples))
anchor = idx1.numpy()[sel_indices].reshape((self.num_samples))
negs = idx3.numpy()[sel_indices].reshape((self.num_samples))
self.negatives += zip(anchor, negs)
self.dist_neg += list(sortd.data.numpy()[sel_indices].reshape((self.num_samples)))
"""
Now get some triplets for regenerating triplets.
"""
def ChooseNegatives(self, num):
l = len(self.negatives)
assert(l >= num)
# sort by distance between anchor and negative
sorted_indices = np.argsort(self.dist_neg)
sel_indices = sorted_indices[0:num]
return ([self.negatives[i] for i in sel_indices])
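# Illustrative sketch (not part of the original file): exercising
# NHardestTripletSampler on synthetic distances. The index tensors stand in for
# dataset indices of anchors/positives/negatives in a batch of size 8.
def _demo_nhardest_sampler():
    torch.manual_seed(0)
    batch = 8
    dista = torch.rand(batch)  # anchor-positive distances
    distb = torch.rand(batch)  # anchor-negative distances
    ids = (torch.arange(batch), torch.arange(batch), torch.arange(batch))
    sampler = NHardestTripletSampler(num_classes=10, num_samples=4)
    sampler.SampleNegatives(dista, distb, None, ids)
    # the two stored pairs with the smallest anchor-negative distance
    print(sampler.ChooseNegatives(2))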
"""
Semihard sampler -- selects examples where distance between anchor and negative
is less than the distance between anchor and positive.
"""
class SemiHardTripletSampler(TripletSampler):
def __init__(self, num_classes, num_samples):
super(SemiHardTripletSampler, self).__init__(num_classes, num_samples)
self.negatives = []
def Reset(self):
self.negatives = []
"""
Negatives with distb < dista.
"""
def SampleNegatives(self, dista, distb, triplet_loss, ids):
dista = dista.cpu()
distb = distb.cpu()
assert(self.num_samples <= dista.size()[0])
idx1, idx2, idx3 = ids
# select examples where distance to negative is less than distance to positive
sel_indices = np.where(distb.data.numpy() < dista.data.numpy())[0]
sel_indices = np.random.choice(sel_indices, self.num_samples)
anchor = idx1.numpy()[sel_indices].reshape((self.num_samples))
negs = idx3.numpy()[sel_indices].reshape((self.num_samples))
self.negatives += zip(anchor, negs)
"""
Now get some triplets for regenerating triplets.
"""
def ChooseNegatives(self, num):
sel_indices = np.random.choice(range(len(self.negatives)), num)
return ([self.negatives[i] for i in sel_indices])
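# Illustrative sketch (not part of the original file): SemiHardTripletSampler keeps
# only triplets whose anchor-negative distance (distb) falls below the
# anchor-positive distance (dista); here dista is fixed at 1.0 so every negative
# qualifies and the sampler draws uniformly from the whole batch.
def _demo_semihard_sampler():
    torch.manual_seed(0)
    batch = 8
    dista = torch.ones(batch)  # anchor-positive distances
    distb = torch.rand(batch)  # anchor-negative distances, all < 1.0
    ids = (torch.arange(batch), torch.arange(batch), torch.arange(batch))
    sampler = SemiHardTripletSampler(num_classes=10, num_samples=4)
    sampler.SampleNegatives(dista, distb, None, ids)
    print(sampler.ChooseNegatives(2))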
"""
Classification based sampler.
"""
class ClassificationBasedSampler(object):
def __init__(self, num_classes, num_samples):
self.num_classes = int(num_classes)
self.num_samples = int(num_samples)
self.negatives = []
def Reset(self):
self.negatives = []
    def SampleNegatives(self, labels_true, labels_pred):
        """Pair each misclassified point with an anchor correctly assigned to the class it was mistaken for."""
        neg_indices = np.where(labels_true != labels_pred)[0]
        true_classes = labels_true[neg_indices]
        pred_classes = labels_pred[neg_indices]  # which cluster does this point falsely belong to?
        cor_indices = np.where(labels_true == labels_pred)[0]  # points that are correctly clustered
        cor_classes = labels_true[cor_indices]
        # for each class, collect the indices of points correctly assigned to it
        anchor_candidates = dict()
        for c in np.unique(cor_classes):
            subset_indices = np.where(cor_classes == c)[0]
            anchor_candidates[c] = cor_indices[subset_indices]
        # now pick an anchor point (a point correctly assigned to the predicted cluster)
        for i in np.random.permutation(len(neg_indices)):
            pred_class = pred_classes[i]  # predicted class of the misclassified point
            if pred_class in anchor_candidates:  # there is an anchor correctly assigned to that class
                self.negatives.append((np.random.choice(anchor_candidates[pred_class]), neg_indices[i]))
            if len(self.negatives) == self.num_samples:
                break
    def ChooseNegatives(self, num):
        """Return num randomly chosen (anchor index, misclassified point index) pairs."""
        sel_indices = np.random.choice(len(self.negatives), num)
        return [self.negatives[i] for i in sel_indices]
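# Illustrative sketch (not part of the original file): ClassificationBasedSampler
# pairs each misclassified point with a point correctly assigned to the class it
# was mistaken for, using plain numpy label arrays.
def _demo_classification_sampler():
    np.random.seed(0)
    labels_true = np.array([0, 0, 1, 1, 2, 2, 2, 1])
    labels_pred = np.array([0, 1, 1, 2, 2, 2, 1, 1])  # points 1, 3 and 6 are misclassified
    sampler = ClassificationBasedSampler(num_classes=3, num_samples=2)
    sampler.SampleNegatives(labels_true, labels_pred)
    # each pair is (index of an anchor correctly in the predicted cluster, index of the misclassified point)
    print(sampler.ChooseNegatives(2))


if __name__ == "__main__":
    # run the illustrative demos when the module is executed directly
    _demo_nhardest_sampler()
    _demo_semihard_sampler()
    _demo_classification_sampler()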