Skip to content

Commit

Permalink
improve result visuals/plots (#271)
Browse files Browse the repository at this point in the history
* improve result visuals/plots

* reduce gpu to cpu copy overhead for yolov5

* don't export visuals by default
  • Loading branch information
fcakyon authored Nov 15, 2021
1 parent 41a67c2 commit 21ecb28
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 190 deletions.
82 changes: 44 additions & 38 deletions demo/inference_for_mmdetection.ipynb

Large diffs are not rendered by default.

99 changes: 37 additions & 62 deletions demo/inference_for_yolov5.ipynb

Large diffs are not rendered by default.

31 changes: 20 additions & 11 deletions sahi/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

import logging
import os
from typing import Dict, List, Optional, Union

import numpy as np

from sahi.prediction import ObjectPrediction
from sahi.utils.torch import cuda_is_available, empty_cuda_cache

# Module-level logger named after this module (standard library-code pattern).
logger = logging.getLogger(__name__)
# NOTE(review): calling basicConfig at import time configures the *root* logger
# as a side effect; library modules normally leave log configuration to the
# application — confirm this is intended.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
# Log level is taken from the LOGLEVEL environment variable, defaulting to INFO.
level=os.environ.get("LOGLEVEL", "INFO").upper(),
)


class DetectionModel:
def __init__(
Expand Down Expand Up @@ -299,15 +308,15 @@ def _create_object_prediction_list_from_original_predictions(

# ignore invalid predictions
if bbox[0] > bbox[2] or bbox[1] > bbox[3] or bbox[0] < 0 or bbox[1] < 0 or bbox[2] < 0 or bbox[3] < 0:
print(f"ignoring invalid prediction with bbox: {bbox}")
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue
if full_shape is not None and (
bbox[1] > full_shape[0]
or bbox[3] > full_shape[0]
or bbox[0] > full_shape[1]
or bbox[2] > full_shape[1]
):
print(f"ignoring invalid prediction with bbox: {bbox}")
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

object_prediction = ObjectPrediction(
Expand Down Expand Up @@ -461,29 +470,29 @@ def _create_object_prediction_list_from_original_predictions(
original_predictions = self._original_predictions

# handle only first image (batch=1)
predictions_in_xyxy_format = original_predictions.xyxy[0]
predictions_in_xyxy_format = original_predictions.xyxy[0].cpu().detach().numpy()

object_prediction_list = []

# process predictions
for prediction in predictions_in_xyxy_format:
x1 = int(prediction[0].item())
y1 = int(prediction[1].item())
x2 = int(prediction[2].item())
y2 = int(prediction[3].item())
x1 = int(prediction[0])
y1 = int(prediction[1])
x2 = int(prediction[2])
y2 = int(prediction[3])
bbox = [x1, y1, x2, y2]
score = prediction[4].item()
category_id = int(prediction[5].item())
score = prediction[4]
category_id = int(prediction[5])
category_name = original_predictions.names[category_id]

# ignore invalid predictions
if bbox[0] > bbox[2] or bbox[1] > bbox[3] or bbox[0] < 0 or bbox[1] < 0 or bbox[2] < 0 or bbox[3] < 0:
print(f"ignoring invalid prediction with bbox: {bbox}")
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue
if full_shape is not None and (
bbox[1] > full_shape[0] or bbox[3] > full_shape[0] or bbox[0] > full_shape[1] or bbox[2] > full_shape[1]
):
print(f"ignoring invalid prediction with bbox: {bbox}")
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue

object_prediction = ObjectPrediction(
Expand Down
81 changes: 41 additions & 40 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ def predict(
postprocess_match_metric: str = "IOS",
postprocess_match_threshold: float = 0.5,
postprocess_class_agnostic: bool = False,
export_visual: bool = True,
export_visual: bool = False,
export_pickle: bool = False,
export_crop: bool = False,
dataset_json_path: bool = None,
project: str = "runs/predict",
name: str = "exp",
visual_bbox_thickness: int = 1,
visual_text_size: float = 0.3,
visual_text_thickness: int = 1,
visual_bbox_thickness: int = None,
visual_text_size: float = None,
visual_text_thickness: int = None,
visual_export_format: str = "png",
verbose: int = 1,
):
Expand Down Expand Up @@ -469,43 +469,44 @@ def predict(
coco_prediction_json = coco_prediction.json
if coco_prediction_json["bbox"]:
coco_json.append(coco_prediction_json)
# convert ground truth annotations to object_prediction_list
coco_image: CocoImage = coco.images[ind]
object_prediction_gt_list: List[ObjectPrediction] = []
for coco_annotation in coco_image.annotations:
coco_annotation_dict = coco_annotation.json
category_name = coco_annotation.category_name
full_shape = [coco_image.height, coco_image.width]
object_prediction_gt = ObjectPrediction.from_coco_annotation_dict(
annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape
if export_visual:
# convert ground truth annotations to object_prediction_list
coco_image: CocoImage = coco.images[ind]
object_prediction_gt_list: List[ObjectPrediction] = []
for coco_annotation in coco_image.annotations:
coco_annotation_dict = coco_annotation.json
category_name = coco_annotation.category_name
full_shape = [coco_image.height, coco_image.width]
object_prediction_gt = ObjectPrediction.from_coco_annotation_dict(
annotation_dict=coco_annotation_dict, category_name=category_name, full_shape=full_shape
)
object_prediction_gt_list.append(object_prediction_gt)
# export visualizations with ground truths
output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent)
color = (0, 255, 0) # original annotations in green
result = visualize_object_predictions(
np.ascontiguousarray(image_as_pil),
object_prediction_list=object_prediction_gt_list,
rect_th=visual_bbox_thickness,
text_size=visual_text_size,
text_th=visual_text_thickness,
color=color,
output_dir=None,
file_name=None,
export_format=None,
)
color = (255, 0, 0) # model predictions in red
_ = visualize_object_predictions(
result["image"],
object_prediction_list=object_prediction_list,
rect_th=visual_bbox_thickness,
text_size=visual_text_size,
text_th=visual_text_thickness,
color=color,
output_dir=output_dir,
file_name=filename_without_extension,
export_format=visual_export_format,
)
object_prediction_gt_list.append(object_prediction_gt)
# export visualizations with ground truths
output_dir = str(visual_with_gt_dir / Path(relative_filepath).parent)
color = (0, 255, 0) # original annotations in green
result = visualize_object_predictions(
np.ascontiguousarray(image_as_pil),
object_prediction_list=object_prediction_gt_list,
rect_th=visual_bbox_thickness,
text_size=visual_text_size,
text_th=visual_text_thickness,
color=color,
output_dir=None,
file_name=None,
export_format=None,
)
color = (255, 0, 0) # model predictions in red
_ = visualize_object_predictions(
result["image"],
object_prediction_list=object_prediction_list,
rect_th=visual_bbox_thickness,
text_size=visual_text_size,
text_th=visual_text_thickness,
color=color,
output_dir=output_dir,
file_name=filename_without_extension,
export_format=visual_export_format,
)

time_start = time.time()
# export prediction boxes
Expand Down
10 changes: 5 additions & 5 deletions sahi/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,15 @@ def __init__(
self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
self.durations_in_seconds = durations_in_seconds

def export_visuals(self, export_dir: str):
def export_visuals(self, export_dir: str, text_size: float = None, rect_th: int = None):
Path(export_dir).mkdir(parents=True, exist_ok=True)
visualize_object_predictions(
image=np.ascontiguousarray(self.image),
object_prediction_list=self.object_prediction_list,
rect_th=1,
text_size=0.3,
text_th=1,
color=(0, 0, 0),
rect_th=rect_th,
text_size=text_size,
text_th=None,
color=None,
output_dir=export_dir,
file_name="prediction_visual",
export_format="png",
Expand Down
Loading

0 comments on commit 21ecb28

Please sign in to comment.