Skip to content

Commit

Permalink
refactor: fix typing on object_detector and nrtk_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Jun 13, 2024
1 parent 64f01e1 commit da1b755
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def compute_annotations(self, ids):
if len(ids) == 0:
return

predictions = self.detector.eval(paths=ids, content=self.context.image_objects)
predictions = self.detector.eval(image_ids=ids, content=self.context.image_objects)

for id_, annotations in predictions:
image_annotations = self.context["annotations"].setdefault(id_, [])
Expand Down
6 changes: 3 additions & 3 deletions src/nrtk_explorer/library/nrtk_transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Union

import numpy as np
import logging
Expand All @@ -25,7 +25,7 @@ def nrtk_transforms_available():


class NrtkGaussianBlurTransform(ImageTransform):
def __init__(self, perturber=None):
def __init__(self, perturber: Union[GaussianBlurPerturber, None] = None):
if perturber is None:
perturber = GaussianBlurPerturber()

Expand Down Expand Up @@ -155,7 +155,7 @@ def createSampleSensorAndScenario():


class NrtkPybsmTransform(ImageTransform):
def __init__(self, perturber=None):
def __init__(self, perturber: Union[PybsmPerturber, None] = None):
if perturber is None:
sensor, scenario = createSampleSensorAndScenario()
perturber = PybsmPerturber(sensor=sensor, scenario=scenario)
Expand Down
35 changes: 18 additions & 17 deletions src/nrtk_explorer/library/object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from nrtk_explorer.library import images_manager

# Type aliases describing the object detector's output structure.
Annotation = dict  # a single detection result, in COCO format
Annotations = list[Annotation]  # all detections for one image
ImageId = str  # identifier for an image (an image path in this codebase)
AnnotatedImage = tuple[ImageId, Annotations]  # (image id, detections for that image)
AnnotatedImages = list[AnnotatedImage]  # detector output: one entry per input image


class ObjectDetector:
Expand Down Expand Up @@ -59,19 +63,16 @@ def pipeline(self, model_name: str):

def eval(
self,
paths: list[str],
image_ids: list[str],
content: Optional[dict] = None,
batch_size: int = 32,
) -> Annotations:
"""Compute object recognition, return it in a list of tuples in the form of [(path, annotations dict in COCO Format)]"""
if len(paths) == 0:
return []

) -> AnnotatedImages:
"""Compute object recognition. Returns Annotations grouped by input image paths."""
images: dict = {}

# Some models require all the images in a batch to be the same size,
# otherwise crash or UB.
for path in paths:
for path in image_ids:
img = None
if content and path in content:
img = content[path]
Expand All @@ -93,12 +94,12 @@ def eval(
for group in images.values()
]
# Flatten the list of predictions
predictions = reduce(operator.iadd, predictions) # type: ignore

output = list()
for path in paths:
for prediction in predictions:
if prediction[0] == path:
output.append(prediction)

return output
predictions = reduce(operator.iadd, predictions, [])

# order output by paths order
find_prediction = lambda id: next(
prediction for prediction in predictions if prediction[0] == id
)
output = [find_prediction(path) for path in image_ids]
# mypy wrongly thinks output's type is list[list[tuple[str, dict]]]
return output # type: ignore
4 changes: 2 additions & 2 deletions tests/test_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def test_detector_small():
ds = json.load(open(DATASET))
sample = [f"{DATASET_PATH}/{img['file_name']}" for img in ds["images"]][:15]
detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny")
img = detector.eval(paths=sample)
img = detector.eval(image_ids=sample)
assert len(img) == 15


def test_nrkt_scorer():
ds = json.load(open(DATASET))
sample = [f"{DATASET_PATH}/{img['file_name']}" for img in ds["images"]]
detector = object_detector.ObjectDetector(model_name="facebook/detr-resnet-50")
predictions = detector.eval(paths=sample)
predictions = detector.eval(image_ids=sample)

dataset_annotations = dict()
for annotation in ds["annotations"]:
Expand Down

0 comments on commit da1b755

Please sign in to comment.