video box2seg #275

Open
wants to merge 2 commits into master
83 changes: 83 additions & 0 deletions ltr/actors/segmentation.py
@@ -138,3 +138,86 @@ def __call__(self, data):
stats['Stats/acc_box_train'] = acc_box/cnt_box

return loss, stats


class STAActor(BaseActor):
"""Actor for training the DiMP network."""
def __init__(self, net, objective, loss_weight=None,
use_focal_loss=False, use_lovasz_loss=False,
detach_pred=True,
num_refinement_iter=3,
disable_backbone_bn=False,
disable_all_bn=False):
super().__init__(net, objective)
if loss_weight is None:
loss_weight = {'segm': 1.0}
self.loss_weight = loss_weight

self.use_focal_loss = use_focal_loss
self.use_lovasz_loss = use_lovasz_loss
self.detach_pred = detach_pred
self.num_refinement_iter = num_refinement_iter
self.disable_backbone_bn = disable_backbone_bn
self.disable_all_bn = disable_all_bn

def train(self, mode=True):
""" Set whether the network is in train mode.
args:
mode (True) - Bool specifying whether in training mode.
"""
self.net.train(mode)

if self.disable_all_bn:
self.net.eval()
elif self.disable_backbone_bn:
for m in self.net.feature_extractor.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()

def __call__(self, data):
"""
args:
            data - The input data, should contain the fields 'train_images', 'train_anno' and 'train_masks'
returns:
loss - the training loss
stats - dict containing detailed losses
"""
segm_pred_bbox, segm_pred = self.net(train_imgs=data['train_images'],
train_bbox=data['train_anno'])
acc = 0
cnt = 0
acc_mid = 0
cnt_mid = 0

segm_pred_bbox = segm_pred_bbox.view(-1, 1, *segm_pred_bbox.shape[-2:])
segm_pred = segm_pred.view(-1, 1, *segm_pred.shape[-2:])
gt_segm = data['train_masks']
gt_segm = gt_segm.view(-1, 1, *gt_segm.shape[-2:])

loss_segm_bbox = self.loss_weight['segm'] * self.objective['segm'](segm_pred_bbox, gt_segm)
loss_segm = self.loss_weight['segm'] * self.objective['segm'](segm_pred, gt_segm)

acc_l = [davis_jaccard_measure(torch.sigmoid(rm.detach()).cpu().numpy() > 0.5, lb.cpu().numpy()) for
rm, lb in zip(segm_pred.view(-1, *segm_pred.shape[-2:]), gt_segm.view(-1, *segm_pred.shape[-2:]))]
acc += sum(acc_l)
cnt += len(acc_l)

acc_l_mid = [davis_jaccard_measure(torch.sigmoid(rm.detach()).cpu().numpy() > 0.5, lb.cpu().numpy()) for
rm, lb in zip(segm_pred_bbox.view(-1, *segm_pred_bbox.shape[-2:]), gt_segm.view(-1, *segm_pred_bbox.shape[-2:]))]
acc_mid += sum(acc_l_mid)
cnt_mid += len(acc_l_mid)

loss = loss_segm_bbox + loss_segm

if torch.isinf(loss) or torch.isnan(loss):
raise Exception('ERROR: Loss was nan or inf!!!')

# Log stats
stats = {'Loss/total': loss.item()}
stats['Loss/segm mid'] = loss_segm_bbox.item()
stats['Loss/segm'] = loss_segm.item()

stats['Stats/acc_mid'] = acc_mid / cnt_mid
stats['Stats/acc'] = acc / cnt
return loss, stats
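
Reviewer note: for reference, the loss and accuracy bookkeeping in STAActor.__call__ boils down to the self-contained sketch below, using random tensors in place of the network output. The BCE objective and the small iou_accuracy helper are illustrative stand-ins for whatever 'segm' objective and davis_jaccard_measure the training settings wire in.

import torch
import torch.nn.functional as F

def iou_accuracy(pred_logits, gt_mask, threshold=0.5):
    # Jaccard (IoU) overlap between a thresholded sigmoid prediction and a binary mask.
    pred = (torch.sigmoid(pred_logits) > threshold).float()
    inter = (pred * gt_mask).sum()
    union = ((pred + gt_mask) > 0).float().sum()
    return (inter / union.clamp(min=1)).item()

n_frames, h, w = 3, 52, 52                        # frames in the sample and mask resolution
segm_pred = torch.randn(n_frames, 1, h, w)        # final segmentation logits
segm_pred_bbox = torch.randn(n_frames, 1, h, w)   # intermediate box-driven logits
gt_segm = (torch.rand(n_frames, 1, h, w) > 0.5).float()

loss_weight = {'segm': 1.0}
loss_segm = loss_weight['segm'] * F.binary_cross_entropy_with_logits(segm_pred, gt_segm)
loss_segm_bbox = loss_weight['segm'] * F.binary_cross_entropy_with_logits(segm_pred_bbox, gt_segm)
loss = loss_segm + loss_segm_bbox

stats = {'Loss/total': loss.item(),
         'Loss/segm': loss_segm.item(),
         'Loss/segm mid': loss_segm_bbox.item(),
         'Stats/acc': sum(iou_accuracy(p, g) for p, g in zip(segm_pred, gt_segm)) / n_frames,
         'Stats/acc_mid': sum(iou_accuracy(p, g) for p, g in zip(segm_pred_bbox, gt_segm)) / n_frames}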
139 changes: 139 additions & 0 deletions ltr/data/processing.py
@@ -776,6 +776,145 @@ def __call__(self, data: TensorDict):
return data


class STAProcessing(BaseProcessing):
""" The processing class used for training DiMP. The images are processed in the following way.
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
centered at the jittered target center, and of area search_area_factor^2 times the area of the jittered box is
cropped from the image. The reason for jittering the target box is to avoid learning the bias that the target is
always at the center of the search region. The search region is then resized to a fixed size given by the
argument output_sz.

"""

def __init__(self, search_area_factor, output_sz, center_jitter_factor, scale_jitter_factor, crop_type='replicate',
max_scale_change=None, mode='pair',
new_roll=False, *args, **kwargs):
"""
args:
search_area_factor - The size of the search region relative to the target size.
output_sz - An integer, denoting the size to which the search region is resized. The search region is always
square.
center_jitter_factor - A dict containing the amount of jittering to be applied to the target center before
extracting the search region. See _get_jittered_box for how the jittering is done.
scale_jitter_factor - A dict containing the amount of jittering to be applied to the target size before
extracting the search region. See _get_jittered_box for how the jittering is done.
crop_type - If 'replicate', the boundary pixels are replicated in case the search region crop goes out of image.
If 'nopad', the search region crop is shifted/shrunk to fit completely inside the image.
mode - Either 'pair' or 'sequence'. If mode='sequence', then output has an extra dimension for frames
"""
super().__init__(*args, **kwargs)
self.search_area_factor = search_area_factor
self.output_sz = output_sz
self.center_jitter_factor = center_jitter_factor
self.scale_jitter_factor = scale_jitter_factor
self.crop_type = crop_type
self.mode = mode
self.max_scale_change = max_scale_change

self.new_roll = new_roll

def _get_jittered_box(self, box, mode):
""" Jitter the input box
args:
box - input bounding box
mode - string 'train' or 'test' indicating train or test data

returns:
torch.Tensor - jittered box
"""

if self.scale_jitter_factor.get('mode', 'gauss') == 'gauss':
jittered_size = box[2:4] * torch.exp(torch.randn(2) * self.scale_jitter_factor[mode])
elif self.scale_jitter_factor.get('mode', 'gauss') == 'uniform':
jittered_size = box[2:4] * torch.exp(torch.FloatTensor(2).uniform_(-self.scale_jitter_factor[mode],
self.scale_jitter_factor[mode]))
else:
            raise Exception('Unknown scale jitter mode {}'.format(self.scale_jitter_factor.get('mode')))

max_offset = (jittered_size.prod().sqrt() * torch.tensor(self.center_jitter_factor[mode])).float()
jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)

return torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)
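
Reviewer note: as a worked example, the 'gauss' branch of _get_jittered_box reduces to the following for a single [x, y, w, h] box; the factor values are illustrative only, not the ones used in the training settings.

import torch

box = torch.tensor([100., 50., 80., 60.])   # x, y, w, h
scale_jitter_factor = 0.25                  # std of the log-scale noise
center_jitter_factor = 3.0                  # centre shift range, in units of sqrt(jittered box area)

jittered_size = box[2:4] * torch.exp(torch.randn(2) * scale_jitter_factor)
max_offset = (jittered_size.prod().sqrt() * torch.tensor(center_jitter_factor)).float()
# the centre moves by at most +/- max_offset / 2 in each direction
jittered_center = box[0:2] + 0.5 * box[2:4] + max_offset * (torch.rand(2) - 0.5)
jittered_box = torch.cat((jittered_center - 0.5 * jittered_size, jittered_size), dim=0)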

def _generate_search_bb(self, boxes_crop, crops, boxes_orig, boxes_jittered):
search_bb = []
anno_search_bb = []
for b_crop, im, b_orig, b_jit in zip(boxes_crop, crops, boxes_orig, boxes_jittered):
output_sz = self.output_sz
if isinstance(output_sz, (float, int)):
output_sz = (output_sz, output_sz)

output_sz = torch.Tensor(output_sz)

resize_factor = b_crop[-1] / b_orig[-1]

b_jit_crop_sz = b_jit[2:] * resize_factor

search_bb_sz = (
output_sz * (b_jit_crop_sz.prod() / output_sz.prod()).sqrt() * self.search_area_factor).ceil()
search_bb.append(torch.cat((torch.zeros(2), search_bb_sz)))

b_sh = b_crop.clone()

anno_search_bb.append(b_sh)
return search_bb, anno_search_bb
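
Reviewer note: a small numeric sketch of the search-box size computed in _generate_search_bb; the boxes and factors below are made up purely to illustrate the arithmetic.

import torch

output_sz = torch.Tensor((352, 352))
search_area_factor = 4.0

b_orig = torch.tensor([100., 50., 80., 60.])    # annotation in the original image
b_crop = torch.tensor([140., 146., 117., 88.])  # same annotation in crop coordinates
b_jit = torch.tensor([95., 45., 90., 66.])      # jittered annotation in the original image

resize_factor = b_crop[-1] / b_orig[-1]          # scale from original image to crop
b_jit_crop_sz = b_jit[2:] * resize_factor        # jittered w/h expressed in crop coordinates
search_bb_sz = (output_sz * (b_jit_crop_sz.prod() / output_sz.prod()).sqrt()
                * search_area_factor).ceil()     # square box scaled by sqrt(jittered area / output area)
search_bb = torch.cat((torch.zeros(2), search_bb_sz))   # [0, 0, w, h], anchored at the crop origin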

def __call__(self, data: TensorDict):
"""
args:
data - The input data, should contain the following fields:
                'train_images' - Train frames
                'train_masks' - Train segmentation masks
                'train_anno' - Train target bounding boxes

        returns:
            TensorDict - output data block with the following fields:
                'train_images' - Cropped and resized train frames
                'train_masks' - Cropped and resized train segmentation masks
                'train_anno' - Train target bounding boxes in crop coordinates
                'train_sa_bb' - Search-area boxes generated by _generate_search_bb
                'train_anno_in_sa' - Target boxes relative to the search area
"""

if self.transform['joint'] is not None:
data['train_images'], data['train_anno'], data['train_masks'] = self.transform['joint'](image=data['train_images'], bbox=data['train_anno'], mask=data['train_masks'])

for s in ['train']:
assert self.mode == 'sequence' or len(data[s + '_images']) == 1, \
"In pair mode, num train/test frames must be 1"

# Add a uniform noise to the center pos
jittered_anno = [self._get_jittered_box(a, s) for a in data[s + '_anno']]
orig_anno = data[s + '_anno']

crops, boxes, mask_crops = prutils.target_image_crop(data[s + '_images'], jittered_anno,
data[s + '_anno'], self.search_area_factor,
self.output_sz, mode=self.crop_type,
max_scale_change=self.max_scale_change,
masks=data[s + '_masks'])

data[s + '_images'], data[s + '_anno'], data[s + '_masks'] = self.transform[s](image=crops, bbox=boxes, mask=mask_crops, joint=False)
# Generate search_bb
sa_bb, anno_in_sa = self._generate_search_bb(boxes, crops, orig_anno, jittered_anno)

data[s + '_sa_bb'] = sa_bb
data[s + '_anno_in_sa'] = anno_in_sa

for s in ['train']:
is_distractor = data.get('is_distractor_{}_frame'.format(s), None)
if is_distractor is not None:
for is_dist, box in zip(is_distractor, data[s+'_anno']):
if is_dist:
box[0] = 99999999.9
box[1] = 99999999.9

# Prepare output
if self.mode == 'sequence':
data = data.apply(stack_tensors)
else:
data = data.apply(lambda x: x[0] if isinstance(x, list) else x)

return data
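
Reviewer note: a minimal sketch of the 'sequence' vs. 'pair' handling at the end of __call__, assuming stack_tensors behaves like torch.stack on lists of per-frame tensors (the helper below is a stand-in, not the repository's implementation).

import torch

def stack_or_unwrap(x, sequence_mode):
    # sequence mode: stack per-frame tensors into one tensor with a leading frame dimension
    # pair mode: unwrap the single element from the list
    if sequence_mode:
        return torch.stack(x) if isinstance(x, list) else x
    return x[0] if isinstance(x, list) else x

crops = [torch.zeros(3, 352, 352) for _ in range(3)]        # three processed frames
print(stack_or_unwrap(crops, sequence_mode=True).shape)     # torch.Size([3, 3, 352, 352])
print(stack_or_unwrap(crops, sequence_mode=False).shape)    # torch.Size([3, 352, 352])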


class KYSProcessing(BaseProcessing):
""" The processing class used for training KYS. The images are processed in the following way.
First, the target bounding box is jittered by adding some noise. Next, a square region (called search region )
159 changes: 159 additions & 0 deletions ltr/data/sampler.py
@@ -377,6 +377,165 @@ def __getitem__(self, index):
return self.processing(data)


class STASampler(torch.utils.data.Dataset):
""" Class responsible for sampling frames from training sequences to form batches. Each training sample is a
tuple consisting of i) a set of train frames, used to learn the DiMP classification model and obtain the
modulation vector for IoU-Net, and ii) a set of test frames on which target classification loss for the predicted
DiMP model, and the IoU prediction loss for the IoU-Net is calculated.

The sampling is done in the following ways. First a dataset is selected at random. Next, a sequence is selected
from that dataset. A base frame is then sampled randomly from the sequence. Next, a set of 'train frames' and
'test frames' are sampled from the sequence from the range [base_frame_id - max_gap, base_frame_id] and
(base_frame_id, base_frame_id + max_gap] respectively. Only the frames in which the target is visible are sampled.
If enough visible frames are not found, the 'max_gap' is increased gradually till enough frames are found.

The sampled frames are then passed through the input 'processing' function for the necessary processing-
"""

def __init__(self, datasets, p_datasets, samples_per_epoch, max_gap,
num_train_frames=1, processing=no_processing, p_reverse=None):
"""
args:
datasets - List of datasets to be used for training
p_datasets - List containing the probabilities by which each dataset will be sampled
samples_per_epoch - Number of training samples per epoch
            max_gap - Maximum gap, in frame numbers, between the base frame and the other sampled train frames.
            num_train_frames - Number of train frames to sample.
            processing - An instance of Processing class which performs the necessary processing of the data.
            p_reverse - Probability of sampling the frames in reverse temporal order.
self.datasets = datasets

# If p not provided, sample uniformly from all videos
if p_datasets is None:
p_datasets = [len(d) for d in self.datasets]

# Normalize
p_total = sum(p_datasets)
self.p_datasets = [x/p_total for x in p_datasets]

self.samples_per_epoch = samples_per_epoch
self.max_gap = max_gap
self.num_train_frames = num_train_frames
self.processing = processing

self.p_reverse = p_reverse

def __len__(self):
return self.samples_per_epoch

def _sample_visible_ids(self, visible, num_ids=1, min_id=None, max_id=None):
""" Samples num_ids frames between min_id and max_id for which target is visible

args:
visible - 1d Tensor indicating whether target is visible for each frame
            num_ids - number of frames to be sampled
min_id - Minimum allowed frame number
max_id - Maximum allowed frame number

returns:
            list - List of sampled frame numbers. None if no visible frames could be found.
"""
if min_id is None or min_id < 0:
min_id = 0
if max_id is None or max_id > len(visible):
max_id = len(visible)

valid_ids = [i for i in range(min_id, max_id) if visible[i]]

# No visible ids
if len(valid_ids) == 0:
return None

return random.choices(valid_ids, k=num_ids)
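
Reviewer note: a tiny usage example of the sampling done by _sample_visible_ids; frame ids are drawn with replacement from the frames in which the target is visible.

import random
import torch

visible = torch.tensor([0, 1, 1, 0, 1, 1, 1, 0], dtype=torch.bool)
valid_ids = [i for i in range(len(visible)) if visible[i]]     # [1, 2, 4, 5, 6]
sampled = random.choices(valid_ids, k=3)                       # e.g. [2, 5, 5]; sampling is with replacement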

def __getitem__(self, index):
"""
args:
index (int): Index (dataset index)

returns:
TensorDict - dict containing all the data blocks
"""

# Select a dataset
# TODO ensure that the dataset can either be used independently, or wrapped with batch sampler
# dataset = self.datasets[index]
dataset = random.choices(self.datasets, self.p_datasets)[0]

is_video_dataset = dataset.is_video_sequence()

reverse_sequence = False
if self.p_reverse is not None:
reverse_sequence = random.random() < self.p_reverse

# Sample a sequence with enough visible frames
enough_visible_frames = False
while not enough_visible_frames:
# Sample a sequence
seq_id = random.randint(0, dataset.get_num_sequences() - 1)

# Sample frames
seq_info_dict = dataset.get_sequence_info(seq_id)
visible = seq_info_dict['visible']

enough_visible_frames = visible.type(torch.int64).sum().item() > 2 * (self.num_train_frames)

enough_visible_frames = enough_visible_frames or not is_video_dataset

if is_video_dataset:
train_frame_ids = None
sample_frame_ids = None
gap_increase = 0

# Sample train frames
while sample_frame_ids is None:
if gap_increase > 1000:
raise Exception('Frame not found')

if not reverse_sequence:
base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=0,
max_id=len(visible)-self.num_train_frames+1)

train_frame_ids = base_frame_id
sample_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0]+1,
max_id=train_frame_ids[0] + self.max_gap + gap_increase,
num_ids=self.num_train_frames-1)

# Increase gap until a frame is found
gap_increase += 5
else:
base_frame_id = self._sample_visible_ids(visible, num_ids=1, min_id=self.num_train_frames - 1,
max_id=len(visible))

train_frame_ids = base_frame_id
sample_frame_ids = self._sample_visible_ids(visible, min_id=train_frame_ids[0]+1 - self.max_gap - gap_increase,
max_id=train_frame_ids[0],
num_ids=self.num_train_frames-1)

# Increase gap until a frame is found
gap_increase += 5
train_frame_ids = train_frame_ids + sample_frame_ids
else:
# In case of image dataset, just repeat the image to generate synthetic video
train_frame_ids = [1]*self.num_train_frames

train_frame_ids = sorted(train_frame_ids, reverse=reverse_sequence)

train_frames, train_anno, meta_obj = dataset.get_frames(seq_id, train_frame_ids, seq_info_dict)

train_frames = train_frames[:len(train_frame_ids)]

train_masks = train_anno['mask'] if 'mask' in train_anno else None

data = TensorDict({'train_images': train_frames,
'train_masks': train_masks,
'train_anno': train_anno['bbox'],
'dataset': dataset.get_name()})

return self.processing(data)
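
Reviewer note: since STASampler is a torch Dataset whose length is samples_per_epoch, it plugs into a standard DataLoader. The dataset list, processing instance and parameter values below are placeholders, not the repository's actual training settings (which may use its own loader wrapper).

import torch.utils.data

# datasets = [youtubevos_train, davis_train]            # LTR dataset wrappers (placeholders)
# processing = STAProcessing(...)                       # as added in ltr/data/processing.py
# sampler = STASampler(datasets, p_datasets=None, samples_per_epoch=1000, max_gap=100,
#                      num_train_frames=3, processing=processing, p_reverse=0.5)
# loader = torch.utils.data.DataLoader(sampler, batch_size=8, num_workers=4)
# for data in loader:                                   # each batch is a dict of stacked fields
#     loss, stats = actor(data)                         # actor = STAActor(...), from ltr/actors/segmentation.py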


class KYSSampler(torch.utils.data.Dataset):
def __init__(self, datasets, p_datasets, samples_per_epoch, sequence_sample_info, processing=no_processing,
sample_occluded_sequences=False):