diff --git a/src/nrtk_explorer/app/image_server.py b/src/nrtk_explorer/app/image_server.py new file mode 100644 index 00000000..7bf6cb70 --- /dev/null +++ b/src/nrtk_explorer/app/image_server.py @@ -0,0 +1,55 @@ +from aiohttp import web +from PIL import Image +import io +from trame.app import get_server +from nrtk_explorer.library.dataset import get_image_path + + +ORIGINAL_IMAGE_ENDPOINT = "original-image" + +server = get_server() + + +def is_browser_compatible_image(file_path): + # Check if the image format is compatible with web browsers + compatible_formats = {"jpg", "jpeg", "png", "gif", "webp"} + return file_path.split(".")[-1].lower() in compatible_formats + + +def make_response(image, format): + bytes_io = io.BytesIO() + image.save(bytes_io, format=format) + bytes_io.seek(0) + return web.Response(body=bytes_io.read(), content_type=f"image/{format.lower()}") + + +async def original_image_endpoint(request: web.Request): + id = request.match_info["id"] + image_path = get_image_path(id) + + if image_path in server.context.images_manager.images: + image = server.context.images_manager.images[image_path] + send_format = "PNG" + if is_browser_compatible_image(image.format): + send_format = image.format.upper() + return make_response(image, send_format) + + if is_browser_compatible_image(image_path): + return web.FileResponse(image_path) + else: + image = Image.open(image_path) + return make_response(image, "PNG") + + +image_routes = [ + web.get(f"/{ORIGINAL_IMAGE_ENDPOINT}/{{id}}", original_image_endpoint), +] + + +def app_available(wslink_server): + """Add our custom REST endpoints to the trame server.""" + wslink_server.app.add_routes(image_routes) + + +# --hot-reload does not work if this is configured as decorator on the function +server.controller.add("on_server_bind")(app_available) diff --git a/src/nrtk_explorer/app/trame_utils.py b/src/nrtk_explorer/app/trame_utils.py index b457589d..40b09e1d 100644 --- a/src/nrtk_explorer/app/trame_utils.py +++ b/src/nrtk_explorer/app/trame_utils.py @@ -18,5 +18,6 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc, tb): self.state.flush() await asyncio.sleep(0) - await asyncio.sleep(0.1) + await asyncio.sleep(0) + await asyncio.sleep(0) await asyncio.sleep(0) diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index 55a0c29c..d3c0da36 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -37,7 +37,8 @@ dataset_id_to_image_id, dataset_id_to_transformed_image_id, ) -from nrtk_explorer.library.dataset import get_dataset +from nrtk_explorer.library.dataset import get_dataset, get_image_path +import nrtk_explorer.app.image_server logger = logging.getLogger(__name__) @@ -114,7 +115,6 @@ def __init__(self, server): def on_server_ready(self, *args, **kwargs): # Bind instance methods to state change - # self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change) self.state.change("current_dataset")(self.on_current_dataset_change) self.state.change("current_num_elements")(self.on_current_num_elements_change) @@ -265,20 +265,16 @@ def compute_predictions_source_images(self, ids): ) async def _update_images(self): - loading = len(self.context.selected_dataset_ids) > 0 + selected_ids = self.context.selected_dataset_ids + loading = len(selected_ids) > 0 async with SetStateAsync(self.state): self.state.loading_images = loading self.state.hovered_id = "" - selected_ids = self.context.selected_dataset_ids - current_dir = os.path.dirname(self.state.current_dataset) - for selected_id in selected_ids: - image_metadata = self.context.dataset.imgs[int(selected_id)] - image_filename = os.path.join(current_dir, image_metadata["file_name"]) - img = self.context.images_manager.load_image(image_filename) + filename = get_image_path(selected_id) + img = self.context.images_manager.load_image(filename) image_id = dataset_id_to_image_id(selected_id) - self.state[image_id] = images_manager.convert_to_base64(img) self.context.image_objects[image_id] = img async with SetStateAsync(self.state): @@ -287,10 +283,8 @@ async def _update_images(self): self.state[image_id_to_result_id(id)] = None self.state[image_id_to_result_id(dataset_id_to_image_id(id))] = None self.state[image_id_to_result_id(dataset_id_to_transformed_image_id(id))] = None - - async with SetStateAsync(self.state): self.state.source_image_ids = [dataset_id_to_image_id(id) for id in selected_ids] - self.state.loading_images = False + self.state.loading_images = False # remove big spinner and show table async with SetStateAsync(self.state): self.load_ground_truth_annotations(selected_ids) @@ -352,12 +346,6 @@ def on_current_dataset_change(self, current_dataset, **kwargs): if self.is_standalone_app: self.context.images_manager = images_manager.ImagesManager() - # No GUI, not tested - # def on_feature_extraction_model_change(self, **kwargs): - # logger.debug(f">>> on_feature_extraction_model_change change {self.state}") - # self.delete_computed_image_data() - # self._start_update_images() - def on_image_hovered(self, id): self.state.hovered_id = id diff --git a/src/nrtk_explorer/app/ui/image_list.py b/src/nrtk_explorer/app/ui/image_list.py index 827c8457..5bf2cef5 100644 --- a/src/nrtk_explorer/app/ui/image_list.py +++ b/src/nrtk_explorer/app/ui/image_list.py @@ -55,7 +55,7 @@ def __init__(self, hover_fn=None, **kwargs): ground_truth_to_transformed_detection_score: meta.ground_truth_to_transformed_detection_score.toFixed(2), original_detection_to_transformed_detection_score: meta.original_detection_to_transformed_detection_score.toFixed(2), id: datasetId, - original: id, + original: `original-image/${datasetId}`, transformed: `transformed_${id}`, groundTruthAnnotations: get(`result_${datasetId}`), originalAnnotations: get(`result_${id}`), @@ -66,6 +66,10 @@ def __init__(self, hover_fn=None, **kwargs): ), row_key="id", rows_per_page_options=("[0]",), # [0] means show all rows + virtual_scroll=True, + __properties=[ + ("virtual_scroll", "virtual-scroll"), + ], ): # ImageDetection component for image columns with html.Template( @@ -76,7 +80,7 @@ def __init__(self, hover_fn=None, **kwargs): ImageDetection( style="max-width: 10rem; float: inline-end;", identifier=("props.row.original",), - src=("get(props.row.original)",), + src=("props.row.original",), annotations=("props.row.groundTruthAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), @@ -93,7 +97,7 @@ def __init__(self, hover_fn=None, **kwargs): ImageDetection( style="max-width: 10rem; float: inline-end;", identifier=("props.row.original",), - src=("get(props.row.original)",), + src=("props.row.original",), annotations=("props.row.originalAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), @@ -132,7 +136,7 @@ def __init__(self, hover_fn=None, **kwargs): ) ImageDetection( identifier=("props.row.original",), - src=("get(props.row.original)",), + src=("props.row.original",), annotations=("props.row.groundTruthAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), @@ -145,7 +149,7 @@ def __init__(self, hover_fn=None, **kwargs): ) ImageDetection( identifier=("props.row.original",), - src=("get(props.row.original)",), + src=("props.row.original",), annotations=("props.row.originalAnnotations",), categories=("annotation_categories",), selected=("(props.row.original == hovered_id)",), diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py index 5fca34a4..58b92b81 100644 --- a/src/nrtk_explorer/library/dataset.py +++ b/src/nrtk_explorer/library/dataset.py @@ -1,17 +1,24 @@ import kwcoco +from pathlib import Path def load_dataset(path: str): return kwcoco.CocoDataset(path) -dataset_json: kwcoco.CocoDataset = kwcoco.CocoDataset() +dataset: kwcoco.CocoDataset = kwcoco.CocoDataset() dataset_path: str = "" def get_dataset(path: str, force_reload=False): - global dataset_json, dataset_path + global dataset, dataset_path if dataset_path != path or force_reload: dataset_path = path - dataset_json = load_dataset(dataset_path) - return dataset_json + dataset = load_dataset(dataset_path) + return dataset + + +def get_image_path(id: str): + dataset_dir = Path(dataset_path).parent + file_name = dataset.imgs[int(id)]["file_name"] + return str(dataset_dir / file_name)