From e01dad4424e5661cc443a4c7775213a18666dcf9 Mon Sep 17 00:00:00 2001 From: fatih <34196005+fcakyon@users.noreply.github.com> Date: Sat, 12 Feb 2022 13:41:00 +0300 Subject: [PATCH] handle nms-postprocess in edge cases (#370) * handle when only 1 prediction selected after nms postprocess * update * handle postprocess when no prediction * fix class agnostic nms * handle when single pred is left after postprocess --- sahi/postprocess/combine.py | 190 ++++-------------------------------- sahi/predict.py | 3 +- 2 files changed, 22 insertions(+), 171 deletions(-) diff --git a/sahi/postprocess/combine.py b/sahi/postprocess/combine.py index c7fcc23a..d5456ef1 100644 --- a/sahi/postprocess/combine.py +++ b/sahi/postprocess/combine.py @@ -13,6 +13,18 @@ def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_threshold: float = 0.5): + """ + Apply non-maximum suppression to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + predictions: (tensor) The location preds for the image + along with the class predscores, Shape: [num_boxes,5]. + match_metric: (str) IOU or IOS + match_threshold: (float) The overlap thresh for + match metric. + Returns: + A list of filtered indexes, Shape: [ ,] + """ scores = predictions[:, 4].squeeze() category_ids = predictions[:, 5].squeeze() keep_mask = torch.zeros_like(category_ids, dtype=torch.bool) @@ -21,7 +33,9 @@ def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_thre curr_keep_indices = nms(predictions[curr_indices], match_metric, match_threshold) keep_mask[curr_indices[curr_keep_indices]] = True keep_indices = torch.where(keep_mask)[0] - return keep_indices[scores[keep_indices].sort(descending=True)[1]] + # sort selected indices by their scores + keep_indices = keep_indices[scores[keep_indices].sort(descending=True)[1]].tolist() + return keep_indices def nms( @@ -130,147 +144,6 @@ def nms( return keep -def batched_nmm_torch0( - object_predictions_as_tensor: torch.tensor, - object_prediction_list: ObjectPredictionList, - match_metric: str = "IOU", - match_threshold: float = 0.5, -): - category_ids = object_predictions_as_tensor[:, 5].squeeze() - keep_mask = torch.zeros_like(category_ids, dtype=torch.bool) - for category_id in torch.unique(category_ids): - curr_indices = torch.where(category_ids == category_id)[0] - curr_object_predictions = object_prediction_list[curr_indices] - curr_keep_indices, curr_object_predictions = nmm_torch0( - object_predictions_as_tensor[curr_indices], curr_object_predictions, match_metric, match_threshold - ) - keep_mask[curr_indices[curr_keep_indices]] = True - object_prediction_list[curr_indices] = curr_object_predictions - keep_indices = torch.where(keep_mask)[0] - scores = object_prediction_list.totensor()[keep_indices, 4].squeeze() - return keep_indices[scores.sort(descending=True)[1]], object_prediction_list - - -def nmm_torch0( - object_predictions_as_tensor: torch.tensor, - object_prediction_list: ObjectPredictionList, - match_metric: str = "IOU", - match_threshold: float = 0.5, -): - """ - Apply non-maximum merging to avoid detecting too many - overlapping bounding boxes for a given object. - Args: - object_predictions_as_tensor: (tensor) The location preds for the image - along with the class predscores, Shape: [num_boxes,5]. - object_predictions_as_list: ObjectPredictionList Object prediction objects - to be merged. - match_metric: (str) IOU or IOS - match_threshold: (float) The overlap thresh for - match metric. - Returns: - A list of filtered indexes, Shape: [ ,] - """ - - # we extract coordinates for every - # prediction box present in P - x1 = object_predictions_as_tensor[:, 0] - y1 = object_predictions_as_tensor[:, 1] - x2 = object_predictions_as_tensor[:, 2] - y2 = object_predictions_as_tensor[:, 3] - - # we extract the confidence scores as well - scores = object_predictions_as_tensor[:, 4] - - # calculate area of every block in P - areas = (x2 - x1) * (y2 - y1) - - # sort the prediction boxes in P - # according to their confidence scores - order = scores.argsort() - - # initialise an empty list for - # filtered prediction boxes - keep = [] - - while len(order) > 0: - # extract the index of the - # prediction with highest score - # we call this prediction S - idx = order[-1] - - # push S in filtered predictions list - keep.append(idx.tolist()) - - # remove S from P - order = order[:-1] - - # sanity check - if len(order) == 0: - break - - # select coordinates of BBoxes according to - # the indices in order - xx1 = torch.index_select(x1, dim=0, index=order) - xx2 = torch.index_select(x2, dim=0, index=order) - yy1 = torch.index_select(y1, dim=0, index=order) - yy2 = torch.index_select(y2, dim=0, index=order) - - # find the coordinates of the intersection boxes - xx1 = torch.max(xx1, x1[idx]) - yy1 = torch.max(yy1, y1[idx]) - xx2 = torch.min(xx2, x2[idx]) - yy2 = torch.min(yy2, y2[idx]) - - # find height and width of the intersection boxes - w = xx2 - xx1 - h = yy2 - yy1 - - # take max with 0.0 to avoid negative w and h - # due to non-overlapping boxes - w = torch.clamp(w, min=0.0) - h = torch.clamp(h, min=0.0) - - # find the intersection area - inter = w * h - - # find the areas of BBoxes according the indices in order - rem_areas = torch.index_select(areas, dim=0, index=order) - - if match_metric == "IOU": - # find the union of every prediction T in P - # with the prediction S - # Note that areas[idx] represents area of S - union = (rem_areas - inter) + areas[idx] - # find the IoU of every prediction in P with S - match_metric_value = inter / union - - elif match_metric == "IOS": - # find the smaller area of every prediction T in P - # with the prediction S - # Note that areas[idx] represents area of S - smaller = torch.min(rem_areas, areas[idx]) - # find the IoS of every prediction in P with S - match_metric_value = inter / smaller - else: - raise ValueError() - - # keep the boxes with IoU/IoS less than thresh_iou - mask = match_metric_value < match_threshold - matched_box_indices = order[(mask == False).nonzero().flatten()].flip(dims=(0,)) - unmatched_indices = order[(mask == True).nonzero().flatten()] - - # merge matching predictions - for matched_box_ind in matched_box_indices: - object_prediction_list[idx] = merge_object_prediction_pair( - object_prediction_list[idx].tolist(), object_prediction_list[matched_box_ind].tolist() - ) - - # update box pool - order = unmatched_indices[scores[unmatched_indices].argsort()] - return keep, object_prediction_list - - def batched_greedy_nmm( object_predictions_as_tensor: torch.tensor, match_metric: str = "IOU", @@ -608,34 +481,9 @@ def __call__( object_predictions_as_torch, match_threshold=self.match_threshold, match_metric=self.match_metric ) - selected_object_predictions = object_prediction_list[keep.tolist()].tolist() - - return selected_object_predictions - - -class NMMPostprocess0(PostprocessPredictions): - def __call__( - self, - object_predictions: List[ObjectPrediction], - ): - object_prediction_list = ObjectPredictionList(object_predictions) - object_predictions_as_torch = object_prediction_list.totensor() - if self.class_agnostic: - keep, object_prediction_list = nmm_torch0( - object_predictions_as_torch, - object_prediction_list, - match_threshold=self.match_threshold, - match_metric=self.match_metric, - ) - else: - keep, object_prediction_list = batched_nmm_torch0( - object_predictions_as_torch, - object_prediction_list, - match_threshold=self.match_threshold, - match_metric=self.match_metric, - ) - - selected_object_predictions = object_prediction_list[keep.tolist()].tolist() + selected_object_predictions = object_prediction_list[keep].tolist() + if not isinstance(selected_object_predictions, list): + selected_object_predictions = [selected_object_predictions] return selected_object_predictions @@ -744,5 +592,7 @@ def __call__( ) selected_object_predictions = object_prediction_list[keep].tolist() + if not isinstance(selected_object_predictions, list): + selected_object_predictions = [selected_object_predictions] return selected_object_predictions diff --git a/sahi/predict.py b/sahi/predict.py index 3817b19c..8316a358 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -286,7 +286,8 @@ def get_sliced_prediction( ) # merge matching predictions - object_prediction_list = postprocess(object_prediction_list) + if len(object_prediction_list) > 0: + object_prediction_list = postprocess(object_prediction_list) return PredictionResult( image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds