Skip to content

Commit

Permalink
refactor scripts (#333)
Browse files Browse the repository at this point in the history
* predict function optionally returns export_dir

* update to v0.8.19

* refactor some function names

* minor fix fo coco2fiftyone

* coco error analysis optionally returns export paths

* simplify predict/image path list generation

* improve compatibility with earlier versions
  • Loading branch information
fcakyon authored Jan 6, 2022
1 parent b88aef5 commit 7f6daa8
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 52 deletions.
2 changes: 1 addition & 1 deletion sahi/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.18"
__version__ = "0.8.19"
8 changes: 4 additions & 4 deletions sahi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from sahi.predict import predict, predict_fiftyone
from sahi.scripts.coco2fiftyone import main as coco2fiftyone
from sahi.scripts.coco2yolov5 import main as coco2yolov5
from sahi.scripts.coco_error_analysis import main as coco_error_analysis
from sahi.scripts.coco_evaluation import main as coco_evaluation
from sahi.scripts.coco_error_analysis import analyse
from sahi.scripts.coco_evaluation import evaluate
from sahi.scripts.slice_coco import main as slice_coco

coco_app = {
"evaluate": coco_evaluation,
"analyse": coco_error_analysis,
"evaluate": evaluate,
"analyse": analyse,
"fiftyone": coco2fiftyone,
"slice": slice_coco,
"yolov5": coco2yolov5,
Expand Down
13 changes: 5 additions & 8 deletions sahi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.torch import cuda_is_available, empty_cuda_cache

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -283,10 +284,8 @@ def _create_object_prediction_list_from_original_predictions(
category_mapping = self.category_mapping

# compatilibty for sahi v0.8.15
if isinstance(shift_amount_list[0], int):
shift_amount_list = [shift_amount_list]
if full_shape_list is not None and isinstance(full_shape_list[0], int):
full_shape_list = [full_shape_list]
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

# parse boxes and masks from predictions
num_categories = self.num_categories
Expand Down Expand Up @@ -449,10 +448,8 @@ def _create_object_prediction_list_from_original_predictions(
original_predictions = self._original_predictions

# compatilibty for sahi v0.8.15
if isinstance(shift_amount_list[0], int):
shift_amount_list = [shift_amount_list]
if full_shape_list is not None and isinstance(full_shape_list[0], int):
full_shape_list = [full_shape_list]
shift_amount_list = fix_shift_amount_list(shift_amount_list)
full_shape_list = fix_full_shape_list(full_shape_list)

# handle all predictions
object_prediction_list_per_image = []
Expand Down
12 changes: 7 additions & 5 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def predict(
visual_text_thickness: int = None,
visual_export_format: str = "png",
verbose: int = 1,
return_dict: bool = False,
):
"""
Performs prediction for all present images in given folder.
Expand Down Expand Up @@ -390,6 +391,8 @@ def predict(
0: no print
1: print slice/prediction durations, number of slices
2: print model loading/file exporting durations
return_dict: bool
If True, returns a dict with 'export_dir' field.
"""
# assert prediction type
assert (
Expand All @@ -405,17 +408,13 @@ def predict(
image_path_list = [str(Path(source) / Path(coco_image.file_name)) for coco_image in coco.images]
coco_json = []
elif os.path.isdir(source):
time_start = time.time()
image_path_list = list_files(
directory=source,
contains=[".jpg", ".jpeg", ".png"],
contains=[".jpg", ".jpeg", ".png", ".tiff", ".bmp"],
verbose=verbose,
)
time_end = time.time() - time_start
durations_in_seconds["list_files"] = time_end
else:
image_path_list = [source]
durations_in_seconds["list_files"] = 0

# init export directories
save_dir = Path(increment_path(Path(project) / name, exist_ok=False)) # increment run
Expand Down Expand Up @@ -600,6 +599,9 @@ def predict(
"seconds.",
)

if return_dict:
return {"export_dir": save_dir}


def predict_fiftyone(
model_type: str = "mmdet",
Expand Down
4 changes: 2 additions & 2 deletions sahi/scripts/coco2fiftyone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main(
image_dir (str): directory for coco images
dataset_json_path (str): file path for the coco dataset json file
result_json_paths (str): one or more paths for the coco result json file
iou_thresh (str): iou threshold for coco evaluation
iou_thresh (float): iou threshold for coco evaluation
"""

from sahi.utils.fiftyone import create_fiftyone_dataset_from_coco_file, fo
Expand Down Expand Up @@ -91,7 +91,7 @@ def main(
gt_field="ground_truth",
eval_key=f"{first_coco_result_name}_eval",
iou=iou_thresh,
compute_mAP=True,
compute_mAP=False,
)
# Get the 10 most common classes in the dataset
counts = dataset.count_values("ground_truth.detections.label")
Expand Down
83 changes: 65 additions & 18 deletions sahi/scripts/coco_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
def _makeplot(rs, ps, outDir, class_name, iou_type):
import matplotlib.pyplot as plt

export_path_list = []

cs = np.vstack(
[
np.ones((2, 3)),
Expand Down Expand Up @@ -47,15 +49,19 @@ def _makeplot(rs, ps, outDir, class_name, iou_type):
plt.title(figure_title)
plt.legend()
# plt.show()
fig.savefig(outDir + f"/{figure_title}.png")
export_path = str(Path(outDir) / f"{figure_title}.png")
fig.savefig(export_path)
plt.close(fig)

export_path_list.append(export_path)
return export_path_list


def _autolabel(ax, rects):
def _autolabel(ax, rects, is_percent=True):
"""Attach a text label above each bar in *rects*, displaying its height."""
for rect in rects:
height = rect.get_height()
if height > 0 and height <= 1: # for percent values
if is_percent and height > 0 and height <= 1: # for percent values
text_label = "{:2.0f}".format(height * 100)
else:
text_label = "{:2.0f}".format(height)
Expand Down Expand Up @@ -104,9 +110,12 @@ def _makebarplot(rs, ps, outDir, class_name, iou_type):
_autolabel(ax, rects)

# Save plot
fig.savefig(outDir + f"/{figure_title}.png")
export_path = str(Path(outDir) / f"{figure_title}.png")
fig.savefig(export_path)
plt.close(fig)

return export_path


def _get_gt_area_group_numbers(cocoEval):
areaRng = cocoEval.params.areaRng
Expand Down Expand Up @@ -146,13 +155,16 @@ def _make_gt_area_group_numbers_plot(cocoEval, outDir, verbose=True):
ax.set_xticklabels(areaRngLbl)

# Add score texts over bars
_autolabel(ax, rects)
_autolabel(ax, rects, is_percent=False)

# Save plot
export_path = str(Path(outDir) / f"{figure_title}.png")
fig.tight_layout()
fig.savefig(outDir + f"/{figure_title}.png")
fig.savefig(export_path)
plt.close(fig)

return export_path


def _make_gt_area_histogram_plot(cocoEval, outDir):
import matplotlib.pyplot as plt
Expand All @@ -173,10 +185,13 @@ def _make_gt_area_histogram_plot(cocoEval, outDir):
ax.set_title(figure_title)

# Save plot
export_path = str(Path(outDir) / f"{figure_title}.png")
fig.tight_layout()
fig.savefig(outDir + f"/{figure_title}.png")
fig.savefig(export_path)
plt.close(fig)

return export_path


def _analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type, areas=None, max_detections=None, COCOeval=None):
nm = cocoGt.loadCats(catId)[0]
Expand Down Expand Up @@ -242,7 +257,7 @@ def _analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type, areas=None,
return k, ps_


def _analyze_results(
def _analyse_results(
res_file,
ann_file,
res_types,
Expand Down Expand Up @@ -270,6 +285,8 @@ def _analyze_results(
print(f"-------------create {out_dir}-----------------")
os.makedirs(directory)

result_type_to_export_paths = {}

cocoGt = COCO(ann_file)
cocoDt = cocoGt.loadRes(res_file)
imgIds = cocoGt.getImgIds()
Expand Down Expand Up @@ -315,6 +332,8 @@ def _analyze_results(
for k, catId in enumerate(present_cat_ids)
]
analyze_results = pool.starmap(_analyze_individual_category, args)

classname_to_export_path_list = {}
for k, catId in enumerate(present_cat_ids):

nm = cocoGt.loadCats(catId)[0]
Expand All @@ -332,35 +351,61 @@ def _analyze_results(
ps[5, :, k, :, :][ps[4, :, k, :, :] > 0] = 1
ps[6, :, k, :, :] = 1.0

_makeplot(recThrs, ps[:, :, k], res_out_dir, nm["name"], iou_type)
roc_curve_export_path_list = _makeplot(recThrs, ps[:, :, k], res_out_dir, nm["name"], iou_type)

if extraplots:
_makebarplot(recThrs, ps[:, :, k], res_out_dir, nm["name"], iou_type)
_makeplot(recThrs, ps, res_out_dir, "allclass", iou_type)
bar_plot_path = _makebarplot(recThrs, ps[:, :, k], res_out_dir, nm["name"], iou_type)
else:
bar_plot_path = None
classname_to_export_path_list[nm["name"]] = {
"roc_curves": roc_curve_export_path_list,
"bar_plot": bar_plot_path,
}

roc_curve_export_path_list = _makeplot(recThrs, ps, res_out_dir, "allclass", iou_type)
if extraplots:
_makebarplot(recThrs, ps, res_out_dir, "allclass", iou_type)
_make_gt_area_group_numbers_plot(cocoEval=cocoEval, outDir=res_out_dir, verbose=True)
_make_gt_area_histogram_plot(cocoEval=cocoEval, outDir=res_out_dir)
bar_plot_path = _makebarplot(recThrs, ps, res_out_dir, "allclass", iou_type)
gt_area_group_numbers_plot_path = _make_gt_area_group_numbers_plot(
cocoEval=cocoEval, outDir=res_out_dir, verbose=True
)
gt_area_histogram_plot_path = _make_gt_area_histogram_plot(cocoEval=cocoEval, outDir=res_out_dir)
else:
bar_plot_path, gt_area_group_numbers_plot_path, gt_area_histogram_plot_path = None, None, None

result_type_to_export_paths[res_type] = {
"classwise": classname_to_export_path_list,
"overall": {
"bar_plot": bar_plot_path,
"roc_curves": roc_curve_export_path_list,
"gt_area_group_numbers": gt_area_group_numbers_plot_path,
"gt_area_histogram": gt_area_histogram_plot_path,
},
}
print(f"COCO error analysis results are successfully exported to {out_dir}")

return result_type_to_export_paths


def main(
def analyse(
dataset_json_path: str,
result_json_path: str,
out_dir: str = None,
type: str = "bbox",
extraplots: bool = False,
areas: List[int] = [1024, 9216, 10000000000],
max_detections: int = 500,
return_dict: bool = False,
):
"""
Args:
dataset_json_path (str): file path for the coco dataset json file
result_json_paths (str): file path for the coco result json file
out_dir (str): dir to save analyze result images
out_dir (str): dir to save analyse result images
extraplots (bool): export extra bar/stat plots
type (str): 'bbox' or 'mask'
areas (List[int]): area regions for coco evaluation calculations
max_detections (int): Maximum number of detections to consider for AP alculation. Default: 500
return_dict (bool): If True, returns a dict export paths.
"""
try:
from pycocotools.coco import COCO
Expand All @@ -376,7 +421,7 @@ def main(
'Please run "pip install -U matplotlib" ' "to install matplotlib first for visualization."
)

_analyze_results(
result = _analyse_results(
result_json_path,
dataset_json_path,
res_types=[type],
Expand All @@ -387,7 +432,9 @@ def main(
COCO=COCO,
COCOeval=COCOeval,
)
if return_dict:
return result


if __name__ == "__main__":
fire.Fire(main)
fire.Fire(analyse)
Loading

0 comments on commit 7f6daa8

Please sign in to comment.