diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index d1324159..72e911c9 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -204,7 +204,7 @@ def compute_annotations(self, ids): if len(ids) == 0: return - predictions = self.detector.eval(paths=ids, content=self.context.image_objects) + predictions = self.detector.eval(image_ids=ids, content=self.context.image_objects) for id_, annotations in predictions: image_annotations = self.context["annotations"].setdefault(id_, []) diff --git a/src/nrtk_explorer/library/nrtk_transforms.py b/src/nrtk_explorer/library/nrtk_transforms.py index be891c61..d064cb8c 100644 --- a/src/nrtk_explorer/library/nrtk_transforms.py +++ b/src/nrtk_explorer/library/nrtk_transforms.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Union import numpy as np import logging @@ -25,7 +25,7 @@ def nrtk_transforms_available(): class NrtkGaussianBlurTransform(ImageTransform): - def __init__(self, perturber=None): + def __init__(self, perturber: Union[GaussianBlurPerturber, None] = None): if perturber is None: perturber = GaussianBlurPerturber() @@ -155,7 +155,7 @@ def createSampleSensorAndScenario(): class NrtkPybsmTransform(ImageTransform): - def __init__(self, perturber=None): + def __init__(self, perturber: Union[PybsmPerturber, None] = None): if perturber is None: sensor, scenario = createSampleSensorAndScenario() perturber = PybsmPerturber(sensor=sensor, scenario=scenario) diff --git a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/object_detector.py index 945726bd..72eaccae 100644 --- a/src/nrtk_explorer/library/object_detector.py +++ b/src/nrtk_explorer/library/object_detector.py @@ -8,7 +8,11 @@ from nrtk_explorer.library import images_manager -Annotations = list[list[tuple[str, dict]]] +Annotation = dict # in COCO format +Annotations = list[Annotation] +ImageId = str +AnnotatedImage = tuple[ImageId, Annotations] +AnnotatedImages = list[AnnotatedImage] class ObjectDetector: @@ -59,19 +63,16 @@ def pipeline(self, model_name: str): def eval( self, - paths: list[str], + image_ids: list[str], content: Optional[dict] = None, batch_size: int = 32, - ) -> Annotations: - """Compute object recognition, return it in a list of tuples in the form of [(path, annotations dict in COCO Format)]""" - if len(paths) == 0: - return [] - + ) -> AnnotatedImages: + """Compute object recognition. Returns Annotations grouped by input image paths.""" images: dict = {} # Some models require all the images in a batch to be the same size, # otherwise crash or UB. - for path in paths: + for path in image_ids: img = None if content and path in content: img = content[path] @@ -93,12 +94,12 @@ def eval( for group in images.values() ] # Flatten the list of predictions - predictions = reduce(operator.iadd, predictions) # type: ignore - - output = list() - for path in paths: - for prediction in predictions: - if prediction[0] == path: - output.append(prediction) - - return output + predictions = reduce(operator.iadd, predictions, []) + + # order output by paths order + find_prediction = lambda id: next( + prediction for prediction in predictions if prediction[0] == id + ) + output = [find_prediction(path) for path in image_ids] + # mypy wrongly thinks output's type is list[list[tuple[str, dict]]] + return output # type: ignore diff --git a/tests/test_object_detector.py b/tests/test_object_detector.py index 72cfb238..ab2e30f7 100644 --- a/tests/test_object_detector.py +++ b/tests/test_object_detector.py @@ -18,7 +18,7 @@ def test_detector_small(): ds = json.load(open(DATASET)) sample = [f"{DATASET_PATH}/{img['file_name']}" for img in ds["images"]][:15] detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny") - img = detector.eval(paths=sample) + img = detector.eval(image_ids=sample) assert len(img) == 15 @@ -26,7 +26,7 @@ def test_nrkt_scorer(): ds = json.load(open(DATASET)) sample = [f"{DATASET_PATH}/{img['file_name']}" for img in ds["images"]] detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50") - predictions = detector.eval(paths=sample) + predictions = detector.eval(image_ids=sample) dataset_annotations = dict() for annotation in ds["annotations"]: