#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 23 15:47:08 2019
@author: manoj
"""
from torchvision.models.detection.roi_heads import RoIHeads, fastrcnn_loss, maskrcnn_inference, maskrcnn_loss
import torch.nn.functional as F
import torch
from torchvision.ops import boxes as box_ops
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import roi_align
from torchvision.models.detection import _utils as det_utils


class ModifiedRoIHead(RoIHeads):
    def __init__(self,
                 box_roi_pool,
                 box_head,
                 box_predictor,
                 # Faster R-CNN training
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 bbox_reg_weights,
                 # Faster R-CNN inference
                 score_thresh,
                 nms_thresh,
                 detections_per_img,
                 # Mask
                 mask_roi_pool=None,
                 mask_head=None,
                 mask_predictor=None,
                 keypoint_roi_pool=None,
                 keypoint_head=None,
                 keypoint_predictor=None,
                 ):
        # The mask/keypoint modules are accepted only for signature compatibility;
        # they are always forwarded to the base class as None, so this head runs
        # the box branch alone (hence the "fake ROIHead" banner below).
        super().__init__(box_roi_pool,
                         box_head,
                         box_predictor,
                         # Faster R-CNN training
                         fg_iou_thresh, bg_iou_thresh,
                         batch_size_per_image, positive_fraction,
                         bbox_reg_weights,
                         # Faster R-CNN inference
                         score_thresh,
                         nms_thresh,
                         detections_per_img,
                         # Mask
                         mask_roi_pool=None,
                         mask_head=None,
                         mask_predictor=None,
                         keypoint_roi_pool=None,
                         keypoint_head=None,
                         keypoint_predictor=None,
                         )
        print(" ----- Using fake ROIHead -----")
    def forward(self, features, proposals, image_shapes, targets=None):
        """
        Arguments:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        if targets is not None:
            for t in targets:
                assert t["boxes"].dtype.is_floating_point, 'target boxes must be of float type'
                assert t["labels"].dtype == torch.int64, 'target labels must be of int64 type'
                if self.has_keypoint:
                    assert t["keypoints"].dtype == torch.float32, 'target keypoints must be of float type'

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)

        # Box branch: pool ROI features, then run the box head and the classifier/regressor.
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result, losses = [], {}
        if self.training:
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets)
            losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg)
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    dict(
                        boxes=boxes[i],
                        labels=labels[i],
                        scores=scores[i],
                    )
                )
        return result, losses
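

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way this box-only head
# could be dropped into a stock torchvision Faster R-CNN.  The threshold and
# sampling values below are the usual fasterrcnn_resnet50_fpn defaults, assumed
# here for illustration rather than taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.models.detection import fasterrcnn_resnet50_fpn

    model = fasterrcnn_resnet50_fpn()
    old_heads = model.roi_heads

    # Reuse the existing pooling, head, and predictor modules; the mask/keypoint
    # branches stay disabled regardless of what is passed to ModifiedRoIHead.
    model.roi_heads = ModifiedRoIHead(
        box_roi_pool=old_heads.box_roi_pool,
        box_head=old_heads.box_head,
        box_predictor=old_heads.box_predictor,
        # Faster R-CNN training (assumed defaults)
        fg_iou_thresh=0.5, bg_iou_thresh=0.5,
        batch_size_per_image=512, positive_fraction=0.25,
        bbox_reg_weights=None,
        # Faster R-CNN inference (assumed defaults)
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=100,
    )

    # Quick smoke test on a random image: inference returns boxes/labels/scores.
    model.eval()
    with torch.no_grad():
        detections = model([torch.rand(3, 300, 400)])
    print(detections[0].keys())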