Skip to content

Commit

Permalink
improve test durations (#252)
Browse files Browse the repository at this point in the history
* improve test durations

* update tests

* fix test
  • Loading branch information
fcakyon authored Oct 19, 2021
1 parent 842c39f commit 55b2641
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
45 changes: 22 additions & 23 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,12 @@ def test_get_prediction_mmdet(self):

# get full sized prediction
prediction_result = get_prediction(
image=image,
detection_model=mmdet_detection_model,
shift_amount=[0, 0],
full_shape=None,
image=image, detection_model=mmdet_detection_model, shift_amount=[0, 0], full_shape=None, image_size=320
)
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 23)
self.assertEqual(len(object_prediction_list), 4)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
Expand All @@ -63,23 +60,23 @@ def test_get_prediction_mmdet(self):
for object_prediction in object_prediction_list:
if object_prediction.category.name == "truck":
num_truck += 1
self.assertEqual(num_truck, 3)
self.assertEqual(num_truck, 0)
num_car = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertEqual(num_car, 20)
self.assertEqual(num_car, 3)

def test_get_prediction_yolov5(self):
from sahi.model import Yolov5DetectionModel
from sahi.predict import get_prediction
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model

# init model
download_yolov5s6_model()
download_yolov5n_model()

yolov5_detection_model = Yolov5DetectionModel(
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
confidence_threshold=0.3,
device=None,
category_remapping=None,
Expand All @@ -98,7 +95,7 @@ def test_get_prediction_yolov5(self):
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 12)
self.assertEqual(len(object_prediction_list), 15)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
Expand All @@ -113,7 +110,7 @@ def test_get_prediction_yolov5(self):
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertEqual(num_car, 12)
self.assertEqual(num_car, 15)

def test_get_sliced_prediction_mmdet(self):
from sahi.model import MmdetDetectionModel
Expand Down Expand Up @@ -144,10 +141,12 @@ def test_get_sliced_prediction_mmdet(self):
match_metric = "IOS"
match_threshold = 0.5
class_agnostic = True
image_size = 320

# get sliced prediction
prediction_result = get_sliced_prediction(
image=image_path,
image_size=image_size,
detection_model=mmdet_detection_model,
slice_height=slice_height,
slice_width=slice_width,
Expand All @@ -162,7 +161,7 @@ def test_get_sliced_prediction_mmdet(self):
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 24)
self.assertEqual(len(object_prediction_list), 13)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
Expand All @@ -172,23 +171,23 @@ def test_get_sliced_prediction_mmdet(self):
for object_prediction in object_prediction_list:
if object_prediction.category.name == "truck":
num_truck += 2
self.assertEqual(num_truck, 4)
self.assertEqual(num_truck, 0)
num_car = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertEqual(num_car, 22)
self.assertEqual(num_car, 13)

def test_get_sliced_prediction_yolov5(self):
from sahi.model import Yolov5DetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model

# init model
download_yolov5s6_model()
download_yolov5n_model()

yolov5_detection_model = Yolov5DetectionModel(
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
confidence_threshold=0.3,
device=None,
category_remapping=None,
Expand Down Expand Up @@ -225,7 +224,7 @@ def test_get_sliced_prediction_yolov5(self):
object_prediction_list = prediction_result.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 21)
self.assertEqual(len(object_prediction_list), 19)
num_person = 0
for object_prediction in object_prediction_list:
if object_prediction.category.name == "person":
Expand All @@ -240,12 +239,12 @@ def test_get_sliced_prediction_yolov5(self):
for object_prediction in object_prediction_list:
if object_prediction.category.name == "car":
num_car += 1
self.assertEqual(num_car, 21)
self.assertEqual(num_car, 19)

def test_coco_json_prediction(self):
from sahi.predict import predict
from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_cascade_mask_rcnn_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model

# init model
download_mmdet_cascade_mask_rcnn_model()
Expand Down Expand Up @@ -292,7 +291,7 @@ def test_coco_json_prediction(self):
)

# init model
download_yolov5s6_model()
download_yolov5n_model()

# prepare paths
dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json"
Expand All @@ -304,7 +303,7 @@ def test_coco_json_prediction(self):
shutil.rmtree(project_dir)
predict(
model_type="yolov5",
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
model_config_path=None,
model_confidence_threshold=0.4,
model_device=None,
Expand Down
28 changes: 14 additions & 14 deletions tests/test_yolov5model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import numpy as np

from sahi.utils.cv import read_image
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5s6_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model, download_yolov5s6_model


class TestYolov5DetectionModel(unittest.TestCase):
def test_load_model(self):
from sahi.model import Yolov5DetectionModel

download_yolov5s6_model()
download_yolov5n_model()

yolov5_detection_model = Yolov5DetectionModel(
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
confidence_threshold=0.3,
device=None,
category_remapping=None,
Expand All @@ -29,10 +29,10 @@ def test_perform_inference(self):
from sahi.model import Yolov5DetectionModel

# init model
download_yolov5s6_model()
download_yolov5n_model()

yolov5_detection_model = Yolov5DetectionModel(
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
confidence_threshold=0.5,
device=None,
category_remapping=None,
Expand All @@ -56,7 +56,7 @@ def test_perform_inference(self):
break

# compare
desired_bbox = [321, 322, 383, 362]
desired_bbox = [321, 325, 384, 365]
predicted_bbox = list(map(int, box[:4].tolist()))
margin = 2
for ind, point in enumerate(predicted_bbox):
Expand All @@ -67,10 +67,10 @@ def test_convert_original_predictions(self):
from sahi.model import Yolov5DetectionModel

# init model
download_yolov5s6_model()
download_yolov5n_model()

yolov5_detection_model = Yolov5DetectionModel(
model_path=Yolov5TestConstants.YOLOV5S6_MODEL_PATH,
model_path=Yolov5TestConstants.YOLOV5N_MODEL_PATH,
confidence_threshold=0.5,
device=None,
category_remapping=None,
Expand All @@ -89,20 +89,20 @@ def test_convert_original_predictions(self):
object_prediction_list = yolov5_detection_model.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 9)
self.assertEqual(len(object_prediction_list), 8)
self.assertEqual(object_prediction_list[0].category.id, 2)
self.assertEqual(object_prediction_list[0].category.name, "car")
desired_bbox = [321, 322, 62, 40]
desired_bbox = [321, 325, 63, 40]
predicted_bbox = object_prediction_list[0].bbox.to_coco_bbox()
margin = 2
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
self.assertEqual(object_prediction_list[5].category.id, 2)
self.assertEqual(object_prediction_list[5].category.name, "car")
self.assertEqual(
object_prediction_list[5].bbox.to_coco_bbox(),
[617, 195, 24, 23],
)
desired_bbox = [701, 234, 20, 17]
predicted_bbox = object_prediction_list[5].bbox.to_coco_bbox()
for ind, point in enumerate(predicted_bbox):
assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin

def test_create_original_predictions_from_object_prediction_list(
self,
Expand Down

0 comments on commit 55b2641

Please sign in to comment.