Skip to content

Commit

Permalink
add obj detect model ui
Browse files Browse the repository at this point in the history
  • Loading branch information
vicentebolea committed Mar 22, 2024
1 parent 59c3af8 commit 42d2fbc
Show file tree
Hide file tree
Showing 1,008 changed files with 184 additions and 37 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ dependencies = [
"Pillow",
"scikit-learn==1.4.1.post1",
"smqtk-classifier==0.19.0",
"accelerate",
"smqtk-core==0.19.0",
"smqtk-dataprovider==0.18.0",
"smqtk-descriptors==0.19.0",
"smqtk-detection[torch,centernet]==0.20.1",
"smqtk-image-io==0.17.1",
"tabulate",
"transformers",
"timm",
"torch",
"torchvision",
Expand Down
31 changes: 26 additions & 5 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@

DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__)
DEFAULT_DATASETS = [
f"{DIR_NAME}/coco-od-2017/mini_val2017.json",
f"{DIR_NAME}/OIRDS_v1_0/oirds.json",
f"{DIR_NAME}/OIRDS_v1_0/oirds_test.json",
f"{DIR_NAME}/OIRDS_v1_0/oirds_train.json",
]


Expand Down Expand Up @@ -97,14 +96,18 @@ def __init__(self, server=None):
self.state.client_only("horizontal_split", "vertical_split")

transforms_translator = Translator()
transforms_translator.add_translation("current_model", "current_transforms_model")
transforms_translator.add_translation(
"feature_extraction_model", "current_transforms_model"
)

self._transforms_app = TransformsApp(
server=self.server.create_child_server(translator=transforms_translator)
)

embeddings_translator = Translator()
embeddings_translator.add_translation("current_model", "current_embeddings_model")
embeddings_translator.add_translation(
"feature_extraction_model", "current_embeddings_model"
)

self._embeddings_app = EmbeddingsApp(
server=self.server.create_child_server(translator=embeddings_translator),
Expand Down Expand Up @@ -316,14 +319,32 @@ def ui(self, *args, **kwargs):
with transforms_actions_slot:
self._transforms_app.apply_ui()

with html.Div(classes="q-pa-md q-gutter-md"):
(
model_title_slot,
model_content_slot,
model_actions_slot,
) = collapsible_card("collapse_model")

with model_title_slot:
html.Span("Model Settings", classes="text-h6")

with model_content_slot:
self._transforms_app.model_widget()

with model_actions_slot:
self._transforms_app.apply_ui()

with html.Template(v_slot_after=True):
with quasar.QSplitter(
v_model=("vertical_split",),
limits=("[0,100]",),
horizontal=True,
classes="inherit-height zero-height",
before_class="q-pa-md",
after_class="q-pa-md",
):
) as splitter:
print(splitter)
with html.Template(v_slot_before=True):
self._embeddings_app.visualization_widget()

Expand Down
14 changes: 7 additions & 7 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def __init__(self, server):
self.features = None

self.state.client_only("camera_position")
self.state.current_model = "resnet50.a1_in1k"
self.state.feature_extraction_model = "resnet50.a1_in1k"

self.server.controller.add("on_server_ready")(self.on_server_ready)
self.transformed_images_cache = {}

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.on_current_dataset_change()
self.on_current_model_change()
self.on_feature_extraction_model_change()
self.state.change("current_dataset")(self.on_current_dataset_change)
self.state.change("current_model")(self.on_current_model_change)
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

def on_current_model_change(self, **kwargs):
current_model = self.state.current_model
def on_feature_extraction_model_change(self, **kwargs):
feature_extraction_model = self.state.feature_extraction_model
self.extractor = embeddings_extractor.EmbeddingsExtractor(
model_name=current_model, manager=self.context.images_manager
model_name=feature_extraction_model, manager=self.context.images_manager
)

def on_current_dataset_change(self, **kwargs):
Expand Down Expand Up @@ -228,7 +228,7 @@ def settings_widget(self):

quasar.QSelect(
label="Embeddings Model",
v_model=("current_model",),
v_model=("feature_extraction_model",),
options=(
[
{"label": "ResNet50", "value": "resnet50.a1_in1k"},
Expand Down
90 changes: 67 additions & 23 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import nrtk_explorer.library.transforms as trans
import nrtk_explorer.library.nrtk_transforms as nrtk_trans
from nrtk_explorer.library import object_detector
from nrtk_explorer.library import images_manager
from nrtk_explorer.app.ui.image_list import image_list_component
from nrtk_explorer.app.applet import Applet
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self, server):

self._on_transform_fn = None
self.state.models = [k for k in self.models.keys()]
self.state.current_model = self.state.models[0]
self.state.feature_extraction_model = self.state.models[0]

self._transforms: Dict[str, trans.ImageTransform] = {
"identity": trans.IdentityTransform(),
Expand All @@ -106,10 +107,11 @@ def __init__(self, server):

self.server.controller.add("on_server_ready")(self.on_server_ready)
self._on_hover_fn = None
self.detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny")

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.state.change("current_model")(self.on_current_model_change)
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
self.state.change("current_dataset")(self.on_current_dataset_change)
self.state.change("current_num_elements")(self.on_current_num_elements_change)

Expand All @@ -128,7 +130,6 @@ def on_apply_transform(self, *args, **kwargs):
current_transform = self.state.current_transform
transformed_image_ids = []
transform = self._transforms[current_transform]

for image_id in self.state.source_image_ids:
image = self.context["image_objects"][image_id]

Expand All @@ -148,13 +149,46 @@ def on_apply_transform(self, *args, **kwargs):
self.state.hovered_id = -1

self.state.transformed_image_ids = transformed_image_ids

self.update_model_result(self.state.transformed_image_ids, self.state.current_model)
self.compute_annotations(transformed_image_ids)

# Only invoke callbacks when we transform images
if len(transformed_image_ids) > 0:
self.on_transform(transformed_image_ids)

def compute_annotations(self, ids):
"""Compute annotations for the given image ids using the object detector model."""
if len(ids) == 0:
return

for id_ in ids:
self.context["annotations"][id_] = []

prediction = self.detector.eval(paths=ids, content=self.context.image_objects)

for id_, annotations in zip(ids, prediction):
image_annotations = self.context["annotations"].setdefault(id_, [])
for prediction in annotations:
category_id = 0
for cat_id, cat in self.state.annotation_categories.items():
if cat["name"] == prediction["label"]:
category_id = cat_id

bbox = prediction["box"]
image_annotations.append(
{
"category_id": category_id,
"id": category_id,
"bbox": [
bbox["xmin"],
bbox["ymin"],
bbox["xmax"] - bbox["xmin"],
bbox["ymax"] - bbox["ymin"],
],
}
)

self.update_model_result(ids, self.state.feature_extraction_model)

def on_current_num_elements_change(self, current_num_elements, **kwargs):
with open(self.state.current_dataset) as f:
dataset = json.load(f)
Expand Down Expand Up @@ -183,7 +217,7 @@ def on_selected_images_change(self, selected_ids):

image_filename = os.path.join(current_dir, image_metadata["file_name"])

img = self.context.images_manager.load_thumbnail(image_filename)
img = self.context.images_manager.load_image(image_filename)

self.state[image_id] = images_manager.convert_to_base64(img)
self.state[meta_id] = {
Expand All @@ -195,8 +229,8 @@ def on_selected_images_change(self, selected_ids):
self.context.image_objects[image_id] = img

self.state.source_image_ids = source_image_ids

self.update_model_result(self.state.source_image_ids, self.state.current_model)
self.compute_annotations(source_image_ids)
self.update_model_result(self.state.source_image_ids, self.state.feature_extraction_model)
self.on_apply_transform()

def reset_data(self):
Expand Down Expand Up @@ -260,27 +294,18 @@ def on_current_dataset_change(self, current_dataset, **kwargs):
for i, image in enumerate(dataset["images"]):
self.context.image_id_to_index[image["id"]] = i

for annotation in dataset["annotations"]:
image_id = f"img_{annotation['image_id']}"
image_annotations = self.context["annotations"].setdefault(image_id, [])
image_annotations.append(annotation)

transformed_image_id = f"transformed_{image_id}"
image_annotations = self.context["annotations"].setdefault(transformed_image_id, [])
image_annotations.append(annotation)

if self.is_standalone_app:
self.context.images_manager = images_manager.ImagesManager()

def on_current_model_change(self, **kwargs):
logger.info(f">>> ENGINE(a): on_current_model_change change {self.state}")
def on_feature_extraction_model_change(self, **kwargs):
logger.info(f">>> ENGINE(a): on_feature_extraction_model_change change {self.state}")

current_model = self.state.current_model
feature_extraction_model = self.state.feature_extraction_model

self.update_model_result(self.state.source_image_ids, current_model)
self.update_model_result(self.state.transformed_image_ids, current_model)
self.update_model_result(self.state.source_image_ids, feature_extraction_model)
self.update_model_result(self.state.transformed_image_ids, feature_extraction_model)

def update_model_result(self, image_ids, current_model):
def update_model_result(self, image_ids, feature_extraction_model):
for image_id in image_ids:
result_id = image_id_to_result(image_id)
self.state[result_id] = self.context["annotations"].get(image_id, [])
Expand Down Expand Up @@ -309,6 +334,25 @@ def settings_widget(self):
):
self._parameters_app.transform_params_ui()

def model_widget(self):
with html.Div(trame_server=self.server, classes="col"):
with html.Div(classes="q-gutter-y-md"):
quasar.QSelect(
label="Object detection Model",
v_model=("object_detection_model", "facebook/detr-resnet-50"),
options=(
[
{
"label": "facebook/detr-resnet-50",
"value": "facebook/detr-resnet-50",
},
],
),
filled=True,
emit_value=True,
map_options=True,
)

def apply_ui(self):
with html.Div(trame_server=self.server):
self._parameters_app.transform_apply_ui()
Expand Down
5 changes: 3 additions & 2 deletions src/nrtk_explorer/library/images_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from PIL import Image as ImageModule

import base64
import copy
import io

# Resolution for images to be used in model
Expand Down Expand Up @@ -39,15 +40,15 @@ def load_image(self, path):
def load_image_for_model(self, path):
"""Load image for model from path and store it in cache if not already loaded"""
if path not in self.images_for_model:
img = self.load_thumbnail(path)
img = copy.copy(self.load_image(path))
self.images_for_model[path] = self.prepare_for_model(img)

return self.images_for_model[path]

def load_thumbnail(self, path):
"""Load thumbnail from path and store it in cache if not already loaded"""
if path not in self.thumbnails:
img = self.load_image(path)
img = copy.copy(self.load_image(path))
img.thumbnail(THUMBNAIL_RESOLUTION)
self.thumbnails[path] = img

Expand Down
59 changes: 59 additions & 0 deletions src/nrtk_explorer/library/object_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
import logging

from nrtk_explorer.library import images_manager
from transformers import pipeline


class ObjectDetector:
def __init__(
self,
model_name="facebook/detr-resnet-50",
manager=images_manager.ImagesManager(),
force_cpu=False,
):
self.manager = manager
self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
self.pipeline = model_name

@property
def device(self):
return self._device

@device.setter
def device(self, dev):
logging.info(f"Using {dev} devices for feature extraction")
self._device = torch.device(dev)

@property
def pipeline(self):
return self._pipeline

@pipeline.setter
def pipeline(self, model_name):
self._pipeline = pipeline(model=model_name, device=self.device)

def eval(self, paths, labels=None, content=None, batch_size=32):
"""Compute object recognition, return it in a dictionary of COCO format"""
if len(paths) == 0:
return None

images = dict()

# Group images by size (shape)
for path in paths:
img = None
if content and path in content:
img = content[path]
else:
img = self.manager.load_image(path)

images.setdefault(img.size, []).append(img)

print(images.keys())
# Call by each grup
output = list()
for group in images.values():
output += self.pipeline(group, batch_size=batch_size)

return output
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 42d2fbc

Please sign in to comment.