forked from visionml/pytracking
-
Notifications
You must be signed in to change notification settings - Fork 0
/
processing.py
322 lines (267 loc) · 15.6 KB
/
processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import torch
import torchvision.transforms as transforms
from pytracking import TensorDict
import ltr.data.processing_utils as prutils
class BaseProcessing:
    """ Common base for the processing classes. A processing object takes the data returned by a dataset and
    prepares it before it is passed through the network, e.g. by cropping a search region around the object or
    applying data augmentations."""

    def __init__(self, transform=transforms.ToTensor(), train_transform=None, test_transform=None, joint_transform=None):
        """
        args:
            transform       - Fallback transformations for the images. Only consulted for a split whose dedicated
                              transform (train_transform / test_transform) is None.
            train_transform - Transformations applied on the train images. Falls back to 'transform' when None.
            test_transform  - Transformations applied on the test images. Falls back to 'transform' when None.
            joint_transform - Transformations applied 'jointly' on the train and test images, e.g. to convert both
                              to grayscale. Stored as given (may be None).
        """
        # Resolve the per-split transforms, falling back to the shared one when none was given
        if train_transform is None:
            train_transform = transform
        if test_transform is None:
            test_transform = transform

        self.transform = {'train': train_transform,
                          'test': test_transform,
                          'joint': joint_transform}

    def __call__(self, data: TensorDict):
        """ Processes a data block. Subclasses must override this; the base class always raises
        NotImplementedError."""
        raise NotImplementedError
class ATOMProcessing(BaseProcessing):
    """ Processing used when training ATOM. Each sample is handled as follows.
    The target bounding box is first perturbed with random noise. A square search region, centered at the
    perturbed target center and of area search_area_factor^2 times the area of the perturbed box, is then cropped
    from the image. Perturbing the box prevents the network from learning the bias that the target always sits at
    the center of the search region. The crop is resized to the fixed size output_sz. Finally, a set of proposals
    is generated for the test images by jittering the ground truth box.
    """

    def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor, proposal_params,
                 mode='pair', *args, **kwargs):
        """
        args:
            search_area_factor   - Size of the search region relative to the target size.
            output_sz            - Integer size to which the (always square) search region is resized.
            center_jitter_factor - Dict, indexed by 'train'/'test', with the amount of jitter applied to the target
                                   center before the search region is extracted. See _get_jittered_box.
            scale_jitter_factor  - Dict, indexed by 'train'/'test', with the amount of jitter applied to the target
                                   size before the search region is extracted. See _get_jittered_box.
            proposal_params      - Settings of the proposal generation. See _generate_proposals.
            mode                 - Either 'pair' or 'sequence'. With mode='sequence' the output has an extra
                                   dimension for frames.
            *args, **kwargs      - Forwarded unchanged to BaseProcessing.__init__.
        """
        super().__init__(*args, **kwargs)

        # Crop geometry
        self.search_area_factor = search_area_factor
        self.output_sz = output_sz

        # Jitter strengths, looked up per split
        self.center_jitter_factor = center_jitter_factor
        self.scale_jitter_factor = scale_jitter_factor

        self.proposal_params = proposal_params
        self.mode = mode

    def _get_jittered_box(self, box, mode):
        """ Randomly perturbs the size and the center of a box.
        args:
            box  - input bounding box; entries 0:2 are the box corner and entries 2:4 its size
            mode - string 'train' or 'test', selecting which jitter factors are used
        returns:
            torch.Tensor - jittered box, in the same layout as the input
        """
        # Size: multiplicative log-normal noise. Drawn before the center noise, so the RNG order is fixed.
        size_noise = torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
        jittered_size = box[2:4] * size_noise

        # Center: uniform noise whose range scales with the sqrt of the jittered box area
        max_offset = (jittered_size.prod().sqrt() * self.center_jitter_factor[mode]).item()
        box_center = box[0:2] + 0.5 * box[2:4]
        jittered_center = box_center + max_offset * (torch.rand(2) - 0.5)

        jittered_corner = jittered_center - 0.5 * jittered_size
        return torch.cat((jittered_corner, jittered_size), dim=0)

    def _generate_proposals(self, box):
        """ Creates proposals by perturbing the input box with prutils.perturb_box.
        args:
            box - input box
        returns:
            torch.Tensor - Array of shape (num_proposals, 4) containing proposals
            torch.Tensor - Array of shape (num_proposals,) containing the overlap value returned for each proposal,
                           mapped linearly from [0, 1] to [-1, 1]
        """
        count = self.proposal_params['boxes_per_frame']
        proposals = torch.zeros((count, 4))
        overlaps = torch.zeros(count)

        for idx in range(count):
            perturbed, overlap = prutils.perturb_box(box, min_iou=self.proposal_params['min_iou'],
                                                     sigma_factor=self.proposal_params['sigma_factor'])
            proposals[idx, :] = perturbed
            overlaps[idx] = overlap

        # Map to [-1, 1]
        return proposals, overlaps * 2 - 1

    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  - sequence of train frames
                'test_images'   - sequence of test frames
                'train_anno'    - one bounding box per train frame
                'test_anno'     - one bounding box per test frame
        returns:
            TensorDict - output data block with following fields:
                'train_images'  - transformed search region crops of the train frames
                'test_images'   - transformed search region crops of the test frames
                'train_anno'    - boxes returned by the cropping routine for the train frames
                'test_anno'     - boxes returned by the cropping routine for the test frames
                'test_proposals'- proposals generated from each entry of 'test_anno'
                'proposal_iou'  - overlap values of the proposals, mapped to [-1, 1]
        """
        # Joint transforms see the train and test frames together in a single call
        joint_transform = self.transform['joint']
        if joint_transform is not None:
            train_count = len(data['train_images'])
            transformed = joint_transform(*(data['train_images'] + data['test_images']))
            data['train_images'] = transformed[:train_count]
            data['test_images'] = transformed[train_count:]

        for split in ('train', 'test'):
            images_key = split + '_images'
            anno_key = split + '_anno'

            assert self.mode == 'sequence' or len(data[images_key]) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Perturb every annotation; the crop is centered at the perturbed box
            jittered_boxes = [self._get_jittered_box(gt_box, split) for gt_box in data[anno_key]]

            crops, boxes = prutils.jittered_center_crop(data[images_key], jittered_boxes, data[anno_key],
                                                        self.search_area_factor, self.output_sz)

            # Per-split transform, applied to each crop on its own
            split_transform = self.transform[split]
            data[images_key] = [split_transform(crop) for crop in crops]
            data[anno_key] = boxes

        # Proposals are generated from the (already cropped) test annotations only
        proposals_per_frame, iou_per_frame = zip(*[self._generate_proposals(gt_box)
                                                   for gt_box in data['test_anno']])
        data['test_proposals'] = list(proposals_per_frame)
        data['proposal_iou'] = list(iou_per_frame)

        # Sequence mode stacks the per-frame entries
        if self.mode == 'sequence':
            return data.apply(prutils.stack_tensors)

        # Pair mode: every list holds a single frame, so unwrap it
        def _first_if_list(value):
            return value[0] if isinstance(value, list) else value

        return data.apply(_first_if_list)
class DiMPProcessing(BaseProcessing):
    """ Processing used when training DiMP. Each sample is handled as follows.
    The target bounding box is first perturbed with random noise. A square search region, centered at the
    perturbed target center and of area search_area_factor^2 times the area of the perturbed box, is then cropped
    from the image. Perturbing the box prevents the network from learning the bias that the target always sits at
    the center of the search region. The crop is resized to the fixed size output_sz. A Gaussian label centered at
    the target is generated for each image; these labels are used for computing the loss of the predicted
    classification model on the test images. Proposals are also generated for the test images by jittering the
    ground truth box, and are used to train the bounding box estimating branch.
    """

    def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor, crop_type='replicate',
                 mode='pair',
                 proposal_params=None, label_function_params=None, *args, **kwargs):
        """
        args:
            search_area_factor    - Size of the search region relative to the target size.
            output_sz             - Integer size to which the (always square) search region is resized.
            center_jitter_factor  - Dict, indexed by 'train'/'test', with the amount of jitter applied to the target
                                    center before the search region is extracted. See _get_jittered_box.
            scale_jitter_factor   - Dict, indexed by 'train'/'test', with the amount of jitter applied to the target
                                    size before the search region is extracted. See _get_jittered_box.
            crop_type             - If 'replicate', the boundary pixels are replicated in case the search region
                                    crop goes out of image. If 'nopad', the search region crop is shifted/shrunk to
                                    fit completely inside the image. Any other value makes __call__ raise
                                    ValueError.
            mode                  - Either 'pair' or 'sequence'. With mode='sequence' the output has an extra
                                    dimension for frames.
            proposal_params       - Settings of the proposal generation. See _generate_proposals. Proposals are
                                    skipped when this is None (or otherwise falsy).
            label_function_params - Settings of the label generation. See _generate_label_function. Labels are
                                    skipped when this is None.
            *args, **kwargs       - Forwarded unchanged to BaseProcessing.__init__.
        """
        super().__init__(*args, **kwargs)

        # Crop geometry and padding behaviour
        self.search_area_factor = search_area_factor
        self.output_sz = output_sz
        self.crop_type = crop_type

        # Jitter strengths, looked up per split
        self.center_jitter_factor = center_jitter_factor
        self.scale_jitter_factor = scale_jitter_factor

        self.mode = mode

        # Optional outputs
        self.proposal_params = proposal_params
        self.label_function_params = label_function_params

    def _get_jittered_box(self, box, mode):
        """ Randomly perturbs the size and the center of a box.
        args:
            box  - input bounding box; entries 0:2 are the box corner and entries 2:4 its size
            mode - string 'train' or 'test', selecting which jitter factors are used
        returns:
            torch.Tensor - jittered box, in the same layout as the input
        """
        # Size: multiplicative log-normal noise. Drawn before the center noise, so the RNG order is fixed.
        size_noise = torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
        jittered_size = box[2:4] * size_noise

        # Center: uniform noise whose range scales with the sqrt of the jittered box area
        max_offset = (jittered_size.prod().sqrt() * self.center_jitter_factor[mode]).item()
        box_center = box[0:2] + 0.5 * box[2:4]
        jittered_center = box_center + max_offset * (torch.rand(2) - 0.5)

        jittered_corner = jittered_center - 0.5 * jittered_size
        return torch.cat((jittered_corner, jittered_size), dim=0)

    def _generate_proposals(self, box):
        """ Creates proposals by perturbing the input box with prutils.perturb_box.
        args:
            box - input box
        returns:
            torch.Tensor - Array of shape (num_proposals, 4) containing proposals
            torch.Tensor - Array of shape (num_proposals,) containing the overlap value returned for each proposal,
                           mapped linearly from [0, 1] to [-1, 1]
        """
        count = self.proposal_params['boxes_per_frame']
        proposals = torch.zeros((count, 4))
        overlaps = torch.zeros(count)

        for idx in range(count):
            perturbed, overlap = prutils.perturb_box(box, min_iou=self.proposal_params['min_iou'],
                                                     sigma_factor=self.proposal_params['sigma_factor'])
            proposals[idx, :] = perturbed
            overlaps[idx] = overlap

        # Map to [-1, 1]
        return proposals, overlaps * 2 - 1

    def _generate_label_function(self, target_bb):
        """ Builds the gaussian label function centered at target_bb via prutils.gaussian_label_function.
        args:
            target_bb - target bounding box (num_images, 4); it is viewed as (-1, 4) before use
        returns:
            torch.Tensor - Tensor of shape (num_images, label_sz, label_sz) containing the label for each sample
        """
        boxes = target_bb.view(-1, 4)
        params = self.label_function_params

        # 'end_pad_if_even' is optional in the settings and defaults to True
        return prutils.gaussian_label_function(boxes, params['sigma_factor'],
                                               params['kernel_sz'],
                                               params['feature_sz'], self.output_sz,
                                               end_pad_if_even=params.get('end_pad_if_even', True))

    def __call__(self, data: TensorDict):
        """
        args:
            data - The input data, should contain the following fields:
                'train_images'  - sequence of train frames
                'test_images'   - sequence of test frames
                'train_anno'    - one bounding box per train frame
                'test_anno'     - one bounding box per test frame
        returns:
            TensorDict - output data block with following fields:
                'train_images'  - transformed search region crops of the train frames
                'test_images'   - transformed search region crops of the test frames
                'train_anno'    - boxes returned by the cropping routine for the train frames
                'test_anno'     - boxes returned by the cropping routine for the test frames
                'test_proposals' (optional) - proposals generated from each entry of 'test_anno'
                'proposal_iou' (optional)   - overlap values of the proposals, mapped to [-1, 1]
                'test_label' (optional)     - gaussian label generated from 'test_anno'
                'train_label' (optional)    - gaussian label generated from 'train_anno'
        """
        # Joint transforms see the train and test frames together in a single call
        joint_transform = self.transform['joint']
        if joint_transform is not None:
            train_count = len(data['train_images'])
            transformed = joint_transform(*(data['train_images'] + data['test_images']))
            data['train_images'] = transformed[:train_count]
            data['test_images'] = transformed[train_count:]

        for split in ('train', 'test'):
            images_key = split + '_images'
            anno_key = split + '_anno'

            assert self.mode == 'sequence' or len(data[images_key]) == 1, \
                "In pair mode, num train/test frames must be 1"

            # Perturb every annotation; the crop is centered at the perturbed box
            jittered_boxes = [self._get_jittered_box(gt_box, split) for gt_box in data[anno_key]]

            # Select the cropping routine matching the requested out-of-image behaviour
            if self.crop_type == 'replicate':
                crop_fn = prutils.jittered_center_crop
            elif self.crop_type == 'nopad':
                crop_fn = prutils.jittered_center_crop_nopad
            else:
                raise ValueError('Unknown crop type {}'.format(self.crop_type))

            crops, boxes = crop_fn(data[images_key], jittered_boxes, data[anno_key],
                                   self.search_area_factor, self.output_sz)

            # Per-split transform, applied to each crop on its own
            split_transform = self.transform[split]
            data[images_key] = [split_transform(crop) for crop in crops]
            data[anno_key] = boxes

        # Proposals are optional and generated from the (already cropped) test annotations only
        if self.proposal_params:
            proposals_per_frame, iou_per_frame = zip(*[self._generate_proposals(gt_box)
                                                       for gt_box in data['test_anno']])
            data['test_proposals'] = list(proposals_per_frame)
            data['proposal_iou'] = list(iou_per_frame)

        # Sequence mode stacks the per-frame entries; pair mode unwraps the single-frame lists
        if self.mode == 'sequence':
            data = data.apply(prutils.stack_tensors)
        else:
            def _first_if_list(value):
                return value[0] if isinstance(value, list) else value

            data = data.apply(_first_if_list)

        # Labels are optional and computed on the final (stacked / unwrapped) annotations, train first
        if self.label_function_params is not None:
            for split in ('train', 'test'):
                data[split + '_label'] = self._generate_label_function(data[split + '_anno'])

        return data