diff --git a/model.py b/model.py index 302e111..fb91c65 100644 --- a/model.py +++ b/model.py @@ -783,7 +783,10 @@ def refine_detections(rois, probs, deltas, window, config): # Filter out low confidence boxes if config.DETECTION_MIN_CONFIDENCE: keep_bool = keep_bool & (class_scores >= config.DETECTION_MIN_CONFIDENCE) - keep = torch.nonzero(keep_bool)[:,0] + detections = torch.nonzero(keep_bool) + if not len(detections): + return + keep = detections[:,0] # Apply per-class NMS pre_nms_class_ids = class_ids[keep.data] @@ -1596,7 +1599,8 @@ def detect(self, images): # Run object detection detections, mrcnn_mask = self.predict([molded_images, image_metas], mode='inference') - + if detections is None: + return # Convert to numpy detections = detections.data.cpu().numpy() mrcnn_mask = mrcnn_mask.permute(0, 1, 3, 4, 2).data.cpu().numpy() @@ -1671,7 +1675,8 @@ def set_bn_eval(m): # Detections # output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in image coordinates detections = detection_layer(self.config, rpn_rois, mrcnn_class, mrcnn_bbox, image_metas) - + if detections is None: + return None, None # Convert boxes to normalized coordinates # TODO: let DetectionLayer return normalized coordinates to avoid # unnecessary conversions