Skip to content

Commit

Permalink
refactor(image-list): add image server
Browse files Browse the repository at this point in the history
Keeps original image off state

Remove commented code
  • Loading branch information
PaulHax committed Jul 15, 2024
1 parent 65e5c8f commit 16a953d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 29 deletions.
55 changes: 55 additions & 0 deletions src/nrtk_explorer/app/image_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from aiohttp import web
from PIL import Image
import io
from trame.app import get_server
from nrtk_explorer.library.dataset import get_image_path


ORIGINAL_IMAGE_ENDPOINT = "original-image"

server = get_server()


def is_browser_compatible_image(file_path):
# Check if the image format is compatible with web browsers
compatible_formats = {"jpg", "jpeg", "png", "gif", "webp"}
return file_path.split(".")[-1].lower() in compatible_formats


def make_response(image, format):
bytes_io = io.BytesIO()
image.save(bytes_io, format=format)
bytes_io.seek(0)
return web.Response(body=bytes_io.read(), content_type=f"image/{format.lower()}")


async def original_image_endpoint(request: web.Request):
id = request.match_info["id"]
image_path = get_image_path(id)

if image_path in server.context.images_manager.images:
image = server.context.images_manager.images[image_path]
send_format = "PNG"
if is_browser_compatible_image(image.format):
send_format = image.format.upper()
return make_response(image, send_format)

if is_browser_compatible_image(image_path):
return web.FileResponse(image_path)
else:
image = Image.open(image_path)
return make_response(image, "PNG")


image_routes = [
web.get(f"/{ORIGINAL_IMAGE_ENDPOINT}/{{id}}", original_image_endpoint),
]


def app_available(wslink_server):
"""Add our custom REST endpoints to the trame server."""
wslink_server.app.add_routes(image_routes)


# --hot-reload does not work if this is configured as decorator on the function
server.controller.add("on_server_bind")(app_available)
3 changes: 2 additions & 1 deletion src/nrtk_explorer/app/trame_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc, tb):
self.state.flush()
await asyncio.sleep(0)
await asyncio.sleep(0.1)
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
26 changes: 7 additions & 19 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
dataset_id_to_image_id,
dataset_id_to_transformed_image_id,
)
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.library.dataset import get_dataset, get_image_path
import nrtk_explorer.app.image_server


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,7 +115,6 @@ def __init__(self, server):

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
# self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
self.state.change("current_dataset")(self.on_current_dataset_change)
self.state.change("current_num_elements")(self.on_current_num_elements_change)

Expand Down Expand Up @@ -265,20 +265,16 @@ def compute_predictions_source_images(self, ids):
)

async def _update_images(self):
loading = len(self.context.selected_dataset_ids) > 0
selected_ids = self.context.selected_dataset_ids
loading = len(selected_ids) > 0
async with SetStateAsync(self.state):
self.state.loading_images = loading
self.state.hovered_id = ""

selected_ids = self.context.selected_dataset_ids
current_dir = os.path.dirname(self.state.current_dataset)

for selected_id in selected_ids:
image_metadata = self.context.dataset.imgs[int(selected_id)]
image_filename = os.path.join(current_dir, image_metadata["file_name"])
img = self.context.images_manager.load_image(image_filename)
filename = get_image_path(selected_id)
img = self.context.images_manager.load_image(filename)
image_id = dataset_id_to_image_id(selected_id)
self.state[image_id] = images_manager.convert_to_base64(img)
self.context.image_objects[image_id] = img

async with SetStateAsync(self.state):
Expand All @@ -287,10 +283,8 @@ async def _update_images(self):
self.state[image_id_to_result_id(id)] = None
self.state[image_id_to_result_id(dataset_id_to_image_id(id))] = None
self.state[image_id_to_result_id(dataset_id_to_transformed_image_id(id))] = None

async with SetStateAsync(self.state):
self.state.source_image_ids = [dataset_id_to_image_id(id) for id in selected_ids]
self.state.loading_images = False
self.state.loading_images = False # remove big spinner and show table

async with SetStateAsync(self.state):
self.load_ground_truth_annotations(selected_ids)
Expand Down Expand Up @@ -352,12 +346,6 @@ def on_current_dataset_change(self, current_dataset, **kwargs):
if self.is_standalone_app:
self.context.images_manager = images_manager.ImagesManager()

# No GUI, not tested
# def on_feature_extraction_model_change(self, **kwargs):
# logger.debug(f">>> on_feature_extraction_model_change change {self.state}")
# self.delete_computed_image_data()
# self._start_update_images()

def on_image_hovered(self, id):
self.state.hovered_id = id

Expand Down
14 changes: 9 additions & 5 deletions src/nrtk_explorer/app/ui/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, hover_fn=None, **kwargs):
ground_truth_to_transformed_detection_score: meta.ground_truth_to_transformed_detection_score.toFixed(2),
original_detection_to_transformed_detection_score: meta.original_detection_to_transformed_detection_score.toFixed(2),
id: datasetId,
original: id,
original: `original-image/${datasetId}`,
transformed: `transformed_${id}`,
groundTruthAnnotations: get(`result_${datasetId}`),
originalAnnotations: get(`result_${id}`),
Expand All @@ -66,6 +66,10 @@ def __init__(self, hover_fn=None, **kwargs):
),
row_key="id",
rows_per_page_options=("[0]",), # [0] means show all rows
virtual_scroll=True,
__properties=[
("virtual_scroll", "virtual-scroll"),
],
):
# ImageDetection component for image columns
with html.Template(
Expand All @@ -76,7 +80,7 @@ def __init__(self, hover_fn=None, **kwargs):
ImageDetection(
style="max-width: 10rem; float: inline-end;",
identifier=("props.row.original",),
src=("get(props.row.original)",),
src=("props.row.original",),
annotations=("props.row.groundTruthAnnotations",),
categories=("annotation_categories",),
selected=("(props.row.original == hovered_id)",),
Expand All @@ -93,7 +97,7 @@ def __init__(self, hover_fn=None, **kwargs):
ImageDetection(
style="max-width: 10rem; float: inline-end;",
identifier=("props.row.original",),
src=("get(props.row.original)",),
src=("props.row.original",),
annotations=("props.row.originalAnnotations",),
categories=("annotation_categories",),
selected=("(props.row.original == hovered_id)",),
Expand Down Expand Up @@ -132,7 +136,7 @@ def __init__(self, hover_fn=None, **kwargs):
)
ImageDetection(
identifier=("props.row.original",),
src=("get(props.row.original)",),
src=("props.row.original",),
annotations=("props.row.groundTruthAnnotations",),
categories=("annotation_categories",),
selected=("(props.row.original == hovered_id)",),
Expand All @@ -145,7 +149,7 @@ def __init__(self, hover_fn=None, **kwargs):
)
ImageDetection(
identifier=("props.row.original",),
src=("get(props.row.original)",),
src=("props.row.original",),
annotations=("props.row.originalAnnotations",),
categories=("annotation_categories",),
selected=("(props.row.original == hovered_id)",),
Expand Down
15 changes: 11 additions & 4 deletions src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import kwcoco
from pathlib import Path


def load_dataset(path: str):
return kwcoco.CocoDataset(path)


dataset_json: kwcoco.CocoDataset = kwcoco.CocoDataset()
dataset: kwcoco.CocoDataset = kwcoco.CocoDataset()
dataset_path: str = ""


def get_dataset(path: str, force_reload=False):
global dataset_json, dataset_path
global dataset, dataset_path
if dataset_path != path or force_reload:
dataset_path = path
dataset_json = load_dataset(dataset_path)
return dataset_json
dataset = load_dataset(dataset_path)
return dataset


def get_image_path(id: str):
dataset_dir = Path(dataset_path).parent
file_name = dataset.imgs[int(id)]["file_name"]
return str(dataset_dir / file_name)

0 comments on commit 16a953d

Please sign in to comment.