Skip to content

Commit

Permalink
seperate mmdet/yolov5 utils (#130)
Browse files Browse the repository at this point in the history
* seperate mmdet and yolov5 utils

* fix test imports

* fix notebook imports
  • Loading branch information
fcakyon authored Jun 12, 2021
1 parent 4f3f01d commit 9ee4b00
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 93 deletions.
2 changes: 1 addition & 1 deletion demo/inference_for_mmdetection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"# arrange an instance segmentation model for test\n",
"from sahi.utils.test import (\n",
"from sahi.utils.mmdet import (\n",
" download_mmdet_cascade_mask_rcnn_model,\n",
" download_mmdet_config,\n",
")\n",
Expand Down
2 changes: 1 addition & 1 deletion demo/inference_for_yolov5.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"# arrange an instance segmentation model for test\n",
"from sahi.utils.test import (\n",
"from sahi.utils.yolov5 import (\n",
" download_yolov5s6_model,\n",
" download_image,\n",
")\n",
Expand Down
22 changes: 0 additions & 22 deletions sahi/utils/test.py → sahi/utils/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ class MmdetTestConstants:
print("warning: mmdet installation not found, omitting MmdetTestConstants")


class Yolov5TestConstants:
YOLOV5S6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s6.pt"
YOLOV5S6_MODEL_PATH = "tests/data/models/yolov5/yolov5s6.pt"

YOLOV5M6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m6.pt"
YOLOV5M6_MODEL_PATH = "tests/data/models/yolov5/yolov5m6.pt"


def download_mmdet_cascade_mask_rcnn_model(destination_path: Optional[str] = None):

if destination_path is None:
Expand Down Expand Up @@ -72,20 +64,6 @@ def download_mmdet_retinanet_model(destination_path: Optional[str] = None):
)


def download_yolov5s6_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov5TestConstants.YOLOV5S6_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov5TestConstants.YOLOV5S6_MODEL_URL,
destination_path,
)


def download_mmdet_config(
model_name: str = "cascade_rcnn",
config_file_name: str = "cascade_mask_rcnn_r50_fpn_1x_coco.py",
Expand Down
26 changes: 26 additions & 0 deletions sahi/utils/yolov5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import urllib.request
from os import path
from pathlib import Path
from typing import Optional


class Yolov5TestConstants:
YOLOV5S6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s6.pt"
YOLOV5S6_MODEL_PATH = "tests/data/models/yolov5/yolov5s6.pt"

YOLOV5M6_MODEL_URL = "https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m6.pt"
YOLOV5M6_MODEL_PATH = "tests/data/models/yolov5/yolov5m6.pt"


def download_yolov5s6_model(destination_path: Optional[str] = None):

if destination_path is None:
destination_path = Yolov5TestConstants.YOLOV5S6_MODEL_PATH

Path(destination_path).parent.mkdir(parents=True, exist_ok=True)

if not path.exists(destination_path):
urllib.request.urlretrieve(
Yolov5TestConstants.YOLOV5S6_MODEL_URL,
destination_path,
)
66 changes: 17 additions & 49 deletions tests/test_mmdetectionmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from sahi.utils.cv import read_image

from sahi.utils.test import (
from sahi.utils.mmdet import (
MmdetTestConstants,
download_mmdet_cascade_mask_rcnn_model,
download_mmdet_retinanet_model,
Expand Down Expand Up @@ -227,33 +227,15 @@ def test_create_original_predictions_from_object_prediction_list_with_mask_outpu

# compare
self.assertEqual(len(original_predictions_1), len(original_predictions_2)) # 2
self.assertEqual(
len(original_predictions_1[0]), len(original_predictions_2[0])
) # 80
self.assertEqual(
len(original_predictions_1[0][2]), len(original_predictions_2[0][2])
) # 25
self.assertEqual(
type(original_predictions_1[0]), type(original_predictions_2[0])
) # list
self.assertEqual(
original_predictions_1[0][2].dtype, original_predictions_2[0][2].dtype
) # float32
self.assertEqual(
original_predictions_1[0][0][0].dtype, original_predictions_2[0][0][0].dtype
) # float32
self.assertEqual(
original_predictions_1[1][0][0].dtype, original_predictions_2[1][0][0].dtype
) # bool
self.assertEqual(
len(original_predictions_1[0][0][0]), len(original_predictions_2[0][0][0])
) # 5
self.assertEqual(
len(original_predictions_1[0][1]), len(original_predictions_1[0][1])
) # 0
self.assertEqual(
original_predictions_1[0][1].shape, original_predictions_1[0][1].shape
) # (0, 5)
self.assertEqual(len(original_predictions_1[0]), len(original_predictions_2[0])) # 80
self.assertEqual(len(original_predictions_1[0][2]), len(original_predictions_2[0][2])) # 25
self.assertEqual(type(original_predictions_1[0]), type(original_predictions_2[0])) # list
self.assertEqual(original_predictions_1[0][2].dtype, original_predictions_2[0][2].dtype) # float32
self.assertEqual(original_predictions_1[0][0][0].dtype, original_predictions_2[0][0][0].dtype) # float32
self.assertEqual(original_predictions_1[1][0][0].dtype, original_predictions_2[1][0][0].dtype) # bool
self.assertEqual(len(original_predictions_1[0][0][0]), len(original_predictions_2[0][0][0])) # 5
self.assertEqual(len(original_predictions_1[0][1]), len(original_predictions_1[0][1])) # 0
self.assertEqual(original_predictions_1[0][1].shape, original_predictions_1[0][1].shape) # (0, 5)

def test_create_original_predictions_from_object_prediction_list_without_mask_output(
self,
Expand Down Expand Up @@ -290,27 +272,13 @@ def test_create_original_predictions_from_object_prediction_list_without_mask_ou

# compare
self.assertEqual(len(original_predictions_1), len(original_predictions_2)) # 80
self.assertEqual(
len(original_predictions_1[2]), len(original_predictions_2[2])
) # 97
self.assertEqual(
type(original_predictions_1), type(original_predictions_2)
) # list
self.assertEqual(
original_predictions_1[2].dtype, original_predictions_2[2].dtype
) # float32
self.assertEqual(
original_predictions_1[2][0].dtype, original_predictions_2[2][0].dtype
) # float32
self.assertEqual(
len(original_predictions_1[2][0]), len(original_predictions_2[2][0])
) # 5
self.assertEqual(
len(original_predictions_1[1]), len(original_predictions_1[1])
) # 0
self.assertEqual(
original_predictions_1[1].shape, original_predictions_1[1].shape
) # (0, 5)
self.assertEqual(len(original_predictions_1[2]), len(original_predictions_2[2])) # 97
self.assertEqual(type(original_predictions_1), type(original_predictions_2)) # list
self.assertEqual(original_predictions_1[2].dtype, original_predictions_2[2].dtype) # float32
self.assertEqual(original_predictions_1[2][0].dtype, original_predictions_2[2][0].dtype) # float32
self.assertEqual(len(original_predictions_1[2][0]), len(original_predictions_2[2][0])) # 5
self.assertEqual(len(original_predictions_1[1]), len(original_predictions_1[1])) # 0
self.assertEqual(original_predictions_1[1].shape, original_predictions_1[1].shape) # (0, 5)


if __name__ == "__main__":
Expand Down
20 changes: 9 additions & 11 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ def test_object_prediction(self):
def test_get_prediction_mmdet(self):
from sahi.model import MmdetDetectionModel
from sahi.predict import get_prediction

from sahi.utils.test import (
from sahi.utils.mmdet import (
MmdetTestConstants,
download_mmdet_cascade_mask_rcnn_model,
)
Expand Down Expand Up @@ -76,8 +75,7 @@ def test_get_prediction_mmdet(self):
def test_get_prediction_yolov5(self):
from sahi.model import Yolov5DetectionModel
from sahi.predict import get_prediction

from sahi.utils.test import (
from sahi.utils.yolov5 import (
Yolov5TestConstants,
download_yolov5s6_model,
)
Expand Down Expand Up @@ -125,8 +123,7 @@ def test_get_prediction_yolov5(self):
def test_get_sliced_prediction_mmdet(self):
from sahi.model import MmdetDetectionModel
from sahi.predict import get_sliced_prediction

from sahi.utils.test import (
from sahi.utils.mmdet import (
MmdetTestConstants,
download_mmdet_cascade_mask_rcnn_model,
)
Expand Down Expand Up @@ -193,7 +190,7 @@ def test_get_sliced_prediction_yolov5(self):
from sahi.model import Yolov5DetectionModel
from sahi.predict import get_sliced_prediction

from sahi.utils.test import (
from sahi.utils.yolov5 import (
Yolov5TestConstants,
download_yolov5s6_model,
)
Expand Down Expand Up @@ -257,13 +254,14 @@ def test_get_sliced_prediction_yolov5(self):

def test_coco_json_prediction(self):
from sahi.predict import predict

from sahi.utils.test import (
MmdetTestConstants,
from sahi.utils.yolov5 import (
Yolov5TestConstants,
download_mmdet_cascade_mask_rcnn_model,
download_yolov5s6_model,
)
from sahi.utils.mmdet import (
MmdetTestConstants,
download_mmdet_cascade_mask_rcnn_model,
)

# init model
download_mmdet_cascade_mask_rcnn_model()
Expand Down
12 changes: 3 additions & 9 deletions tests/test_yolov5model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from sahi.utils.cv import read_image

from sahi.utils.test import (
from sahi.utils.yolov5 import (
Yolov5TestConstants,
download_yolov5s6_model,
)
Expand Down Expand Up @@ -63,10 +63,7 @@ def test_perform_inference(self):
predicted_bbox = list(map(int, box[:4].tolist()))
margin = 2
for ind, point in enumerate(predicted_bbox):
assert (
point < desired_bbox[ind] + margin
and point > desired_bbox[ind] - margin
)
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
self.assertEqual(len(original_predictions.names), 80)

def test_convert_original_predictions(self):
Expand Down Expand Up @@ -102,10 +99,7 @@ def test_convert_original_predictions(self):
predicted_bbox = object_prediction_list[0].bbox.to_coco_bbox()
margin = 2
for ind, point in enumerate(predicted_bbox):
assert (
point < desired_bbox[ind] + margin
and point > desired_bbox[ind] - margin
)
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
self.assertEqual(object_prediction_list[5].category.id, 2)
self.assertEqual(object_prediction_list[5].category.name, "car")
self.assertEqual(
Expand Down

0 comments on commit 9ee4b00

Please sign in to comment.