Skip to content

Commit

Permalink
add concurrent slice export (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored Aug 3, 2021
1 parent 283ed11 commit 5d7526a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
42 changes: 34 additions & 8 deletions sahi/slicing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

import concurrent.futures
import os
import time
from pathlib import Path
Expand All @@ -14,6 +15,8 @@
from sahi.utils.cv import read_image_as_pil
from sahi.utils.file import create_dir, load_json, save_json

MAX_WORKERS = 20


def get_slice_bboxes(
image_height: int,
Expand Down Expand Up @@ -138,15 +141,15 @@ def __init__(self, image, coco_image, starting_pixel):


class SliceImageResult:
def __init__(self, original_image_size=None, image_dir=None):
def __init__(self, original_image_size=None, image_dir: str = None):
"""
sliced_image_list: list of SlicedImage
image_dir: str
Directory of the sliced image exports.
original_image_size: list of int
Size of the unsliced original image in [height, width]
"""
self._sliced_image_list = []
self._sliced_image_list: List[SlicedImage] = []
self.original_image_height = original_image_size[0]
self.original_image_width = original_image_size[1]
self.image_dir = image_dir
Expand Down Expand Up @@ -195,6 +198,18 @@ def starting_pixels(self) -> List[int]:
starting_pixels.append(sliced_image.starting_pixel)
return starting_pixels

@property
def filenames(self) -> List[int]:
"""Returns a list of filenames for each slice.
Returns:
filenames: a list of filenames as str
"""
filenames = []
for sliced_image in self._sliced_image_list:
filenames.append(sliced_image.coco_image.file_name)
return filenames

def __len__(self):
return len(self._sliced_image_list)

Expand Down Expand Up @@ -251,6 +266,13 @@ def slice_image(
# define verboseprint
verboseprint = print if verbose else lambda *a, **k: None

def _export_single_slice(image: np.ndarray, output_dir: str, slice_file_name: str):
image_pil = read_image_as_pil(image)
slice_file_path = str(Path(output_dir) / slice_file_name)
# export sliced image
image_pil.save(slice_file_path)
verboseprint("sliced image path:", slice_file_path)

# create outdir if not present
if output_dir:
create_dir(output_dir)
Expand Down Expand Up @@ -299,12 +321,6 @@ def slice_image(

# set image file name and path
slice_file_name = f"{output_file_name}_{slice_suffixes}{suffix}"
# export image if output directory is provided
if output_file_name and output_dir:
slice_file_path = str(Path(output_dir) / slice_file_name)
# export sliced image
image_pil_slice.save(slice_file_path)
verboseprint("sliced image path:", slice_file_path)

# create coco image
slice_width = slice_bbox[2] - slice_bbox[0]
Expand All @@ -324,6 +340,16 @@ def slice_image(
)
sliced_image_result.add_sliced_image(sliced_image)

# export slices if output directory is provided
if output_file_name and output_dir:
conc_exec = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
conc_exec.map(
_export_single_slice,
sliced_image_result.images,
[output_dir] * len(sliced_image_result),
sliced_image_result.filenames,
)

verboseprint(
"Num slices:",
n_ims,
Expand Down
1 change: 0 additions & 1 deletion scripts/slice_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os

from sahi.slicing import slice_coco
from sahi.utils.coco import Coco, split_coco_as_train_val
from sahi.utils.file import Path, get_base_filename, increment_path, save_json

if __name__ == "__main__":
Expand Down

0 comments on commit 5d7526a

Please sign in to comment.