Skip to content

Commit

Permalink
fix(dataset): set kwcoco optional backend
Browse files Browse the repository at this point in the history
  • Loading branch information
vicentebolea committed Aug 29, 2024
1 parent af34dbb commit 0abe405
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 21 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ classifiers = [

dependencies = [
"accelerate",
"kwcoco",
"nrtk>=0.4.2",
"numpy",
"Pillow",
Expand All @@ -48,6 +47,10 @@ dependencies = [
]

[project.optional-dependencies]
kwcoco= [
"kwcoco",
]

dev = [
"black",
"flake8",
Expand Down
5 changes: 4 additions & 1 deletion src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from typing import Iterable
from pathlib import Path
from functools import partial

from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library import images_manager
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.library.dataset import get_dataset, get_image_fpath

from nrtk_explorer.app.embeddings import EmbeddingsApp
from nrtk_explorer.app.transforms import TransformsApp
Expand Down Expand Up @@ -54,6 +55,8 @@ def __init__(self, server=None):
self.input_paths = known_args.dataset
self.state.current_dataset = str(Path(self.input_paths[0]).resolve())

self.ctrl.get_image_fpath = partial(get_image_fpath, path=self.state.current_dataset)

self.context["image_objects"] = {}
self.context["images_manager"] = images_manager.ImagesManager()

Expand Down
3 changes: 1 addition & 2 deletions src/nrtk_explorer/app/image_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from PIL import Image
import io
from trame.app import get_server
from nrtk_explorer.library.dataset import get_image_path


ORIGINAL_IMAGE_ENDPOINT = "original-image"
Expand All @@ -25,7 +24,7 @@ def make_response(image, format):

async def original_image_endpoint(request: web.Request):
id = request.match_info["id"]
image_path = get_image_path(id)
image_path = server.controller.get_image_fpath(int(id))

if image_path in server.context.images_manager.images:
image = server.context.images_manager.images[image_path]
Expand Down
8 changes: 6 additions & 2 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
dataset_id_to_image_id,
dataset_id_to_transformed_image_id,
)
from nrtk_explorer.library.dataset import get_dataset, get_image_path
from nrtk_explorer.library.dataset import get_dataset
import nrtk_explorer.app.image_server


Expand Down Expand Up @@ -121,6 +121,10 @@ def tranformed_became_visible(old, new):
self.server.controller.add("on_server_ready")(self.on_server_ready)
self._on_hover_fn = None

@property
def get_image_fpath(self):
return self.server.controller.get_image_fpath

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.state.change("current_dataset")(self.on_current_dataset_change)
Expand Down Expand Up @@ -295,7 +299,7 @@ async def _update_images(self):
self.state.hovered_id = ""

for selected_id in selected_ids:
filename = get_image_path(selected_id)
filename = self.get_image_fpath(int(selected_id))
img = self.context.images_manager.load_image(filename)
image_id = dataset_id_to_image_id(selected_id)
self.context.image_objects[image_id] = img
Expand Down
65 changes: 50 additions & 15 deletions src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,59 @@
import kwcoco
"""
Module to load the dataset and get the image file path given an image id.
Example:
dataset = get_dataset("path/to/dataset.json")
image_fpath = dataset.get_image_fpath(image_id)
"""

from functools import lru_cache
from pathlib import Path

import json


class DefaultDataset:
"""Default dataset class to load the dataset and get the image file path given an image id."""

def __init__(self, path: str):
with open(path) as f:
self.data = json.load(f)
self.fpath = path
self.cats = {cat["id"]: cat for cat in self.data["categories"]}
self.anns = {ann["id"]: ann for ann in self.data["annotations"]}
self.imgs = {img["id"]: img for img in self.data["images"]}

def get_image_fpath(self, selected_id: int):
"""Get the image file path given an image id."""
dataset_dir = Path(self.fpath).parent
file_name = self.imgs[selected_id]["file_name"]
return str(dataset_dir / file_name)

def load_dataset(path: str):
return kwcoco.CocoDataset(path)

@lru_cache
def __load_dataset(path: str):
"""Load the dataset given the path to the dataset file."""
try:
import kwcoco

dataset: kwcoco.CocoDataset = kwcoco.CocoDataset()
dataset_path: str = ""
return kwcoco.CocoDataset(path)
except ImportError:
return DefaultDataset(path)


def get_dataset(path: str, force_reload=False):
global dataset, dataset_path
if dataset_path != path or force_reload:
dataset_path = path
dataset = load_dataset(dataset_path)
return dataset
def get_dataset(path: str, force_reload: bool = False):
"""Get the dataset object given the path to the dataset file.
Args:
path (str): Path to the dataset file.
force_reload (bool): Whether to force reload the dataset. Default: False.
Return:
dataset: Dataset object.
"""
if force_reload:
__load_dataset.cache_clear()
return __load_dataset(path)


def get_image_path(id: str):
dataset_dir = Path(dataset_path).parent
file_name = dataset.imgs[int(id)]["file_name"]
return str(dataset_dir / file_name)
def get_image_fpath(selected_id: int, path: str):
"""Get the image file path given an image id."""
return get_dataset(path).get_image_fpath(selected_id)
40 changes: 40 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from nrtk_explorer.library.dataset import get_dataset, DefaultDataset
import nrtk_explorer.test_data

from unittest import mock
from pathlib import Path

import pytest


@pytest.fixture
def dataset_path():
dir_name = Path(nrtk_explorer.test_data.__file__).parent
return f"{dir_name}/coco-od-2017/test_val2017.json"


def test_get_dataset(dataset_path):
ds1 = get_dataset(dataset_path)
assert ds1 is not None

ds1 = get_dataset(dataset_path)
ds2 = get_dataset(dataset_path)
assert ds1 is ds2

ds1 = get_dataset(dataset_path)
ds2 = get_dataset(dataset_path, force_reload=True)
assert ds1 is not ds2


@mock.patch("nrtk_explorer.library.dataset.__load_dataset", lambda path: DefaultDataset(path))
def test_get_dataset_empty():
with pytest.raises(FileNotFoundError):
get_dataset("nonexisting")


def test_DefaultDataset(dataset_path):
ds = DefaultDataset(dataset_path)
assert len(ds.imgs) > 0
assert len(ds.cats) > 0
assert len(ds.anns) > 0
assert Path(ds.get_image_fpath(491497)).name == "000000491497.jpg"

0 comments on commit 0abe405

Please sign in to comment.