Skip to content

Commit

Permalink
feat: use kwcoco for coco datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
vicentebolea committed Jun 28, 2024
1 parent 2237edf commit 58f95b2
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 51 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ classifiers = [

dependencies = [
"accelerate",
"kwcoco",
"nrtk>=0.4.2",
"numpy",
"Pillow",
Expand Down
15 changes: 4 additions & 11 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@

import os

import json
import random


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Expand Down Expand Up @@ -130,10 +128,8 @@ def on_server_ready(self, *args, **kwargs):
def on_dataset_change(self, **kwargs):
# Reset cache
self.context.images_manager = images_manager.ImagesManager()

dataset = get_dataset(self.state.current_dataset, force_reload=True)

self.state.num_images_max = len(dataset["images"])
self.context.dataset = get_dataset(self.state.current_dataset)
self.state.num_images_max = len(self.context.dataset.imgs)
self.state.random_sampling_disabled = False
self.state.num_images_disabled = False

Expand Down Expand Up @@ -162,14 +158,11 @@ def on_random_sampling_change(self, **kwargs):
self.reload_images()

def reload_images(self):
with open(self.state.current_dataset) as f:
dataset = json.load(f)

categories = {}
for category in dataset["categories"]:
for category in self.context.dataset.cats.values():
categories[category["id"]] = category

images = dataset["images"]
images = list(self.context.dataset.imgs.values())

selected_images = []
if self.state.num_images:
Expand Down
11 changes: 6 additions & 5 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import nrtk_explorer.test_data

import asyncio
import json
import os
import kwcoco

from trame.widgets import quasar, html
from trame.ui.quasar import QLayout
Expand Down Expand Up @@ -60,10 +60,11 @@ def on_feature_extraction_model_change(self, **kwargs):

def on_current_dataset_change(self, **kwargs):
self.state.num_elements_disabled = True
with open(self.state.current_dataset) as f:
dataset = json.load(f)
self.images = dataset["images"]
self.state.num_elements_max = len(self.images)
if self.context.dataset is None:
self.context.dataset = kwcoco.CocoDataset(self.state.current_dataset)

self.images = list(self.context.dataset.imgs.values())
self.state.num_elements_max = len(self.images)
self.state.num_elements_disabled = False

if self.is_standalone_app:
Expand Down
35 changes: 9 additions & 26 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,9 @@ def on_apply_transform(self, *args, **kwargs):
if len(transformed_image_ids) == 0:
return

dataset = get_dataset(self.state.current_dataset)

# Erase current annotations
dataset_ids = [image_id_to_dataset_id(id) for id in self.state.source_image_ids]
for ann in dataset["annotations"]:
for ann in self.context.dataset.anns.values():
if str(ann["image_id"]) in dataset_ids:
transformed_id = f"transformed_img_{ann['image_id']}"
if transformed_id in self.context["annotations"]:
Expand Down Expand Up @@ -221,8 +219,7 @@ def compute_annotations(self, ids):
return predictions

def on_current_num_elements_change(self, current_num_elements, **kwargs):
dataset = get_dataset(self.state.current_dataset)
ids = [img["id"] for img in dataset["images"]]
ids = [img["id"] for img in self.context.dataset.imgs.values()]
return self.set_source_images(ids[:current_num_elements])

def compute_predictions_source_images(self, old_ids, ids):
Expand All @@ -236,22 +233,18 @@ def compute_predictions_source_images(self, old_ids, ids):
if len(ids) == 0:
return

dataset = get_dataset(self.state.current_dataset)

annotations = self.compute_annotations(ids)
self.predictions_source_images = convert_from_predictions_to_first_arg(
annotations,
dataset,
self.context.dataset,
ids,
)

# load ground truth annotations
dataset_annotations = dataset["annotations"]
# collect annotations for each dataset_id
annotations = {
dataset_id: [
annotation
for annotation in dataset_annotations
for annotation in self.context.dataset.anns.values()
if str(annotation["image_id"]) == dataset_id
]
for dataset_id in dataset_ids
Expand All @@ -263,7 +256,7 @@ def compute_predictions_source_images(self, old_ids, ids):

ground_truth_annotations = annotations.values()
ground_truth_predictions = convert_from_ground_truth_to_second_arg(
ground_truth_annotations, dataset
ground_truth_annotations, self.context.dataset
)
scores = compute_score(
dataset_ids,
Expand All @@ -280,14 +273,8 @@ def _update_images(self, selected_ids):

current_dir = os.path.dirname(self.state.current_dataset)

dataset = get_dataset(self.state.current_dataset)

for selected_id in selected_ids:
image_index = self.context.image_id_to_index[selected_id]
if image_index >= len(dataset["images"]):
continue

image_metadata = dataset["images"][image_index]
image_metadata = self.context.dataset.imgs[selected_id]
image_id = f"img_{image_metadata['id']}"
source_image_ids.append(image_id)
image_filename = os.path.join(current_dir, image_metadata["file_name"])
Expand Down Expand Up @@ -342,21 +329,17 @@ def reset_data(self):

def on_current_dataset_change(self, current_dataset, **kwargs):
logger.debug(f"on_current_dataset_change change {self.state}")

self.reset_data()

dataset = get_dataset(current_dataset)
categories = {}
if self.context.dataset is None:
self.context.dataset = get_dataset(current_dataset)

for category in dataset["categories"]:
for category in self.context.dataset.cats.values():
categories[category["id"]] = category

self.state.annotation_categories = categories

self.context.image_id_to_index = {}
for i, image in enumerate(dataset["images"]):
self.context.image_id_to_index[image["id"]] = i

if self.is_standalone_app:
self.context.images_manager = images_manager.ImagesManager()

Expand Down
4 changes: 2 additions & 2 deletions src/nrtk_explorer/library/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def convert_from_ground_truth_to_first_arg(dataset_annotations):

def convert_from_ground_truth_to_second_arg(dataset_annotations, dataset):
"""Convert ground truth annotations to COCOScorer format"""
categories = {cat["id"]: cat["name"] for cat in dataset["categories"]}
categories = {cat["id"]: cat["name"] for cat in dataset.cats.values()}
annotations = list()
for dataset_image_annotations in dataset_annotations:
image_annotations = list()
Expand All @@ -59,7 +59,7 @@ def convert_from_ground_truth_to_second_arg(dataset_annotations, dataset):
def convert_from_predictions_to_first_arg(predictions, dataset, ids):
"""Convert predictions to COCOScorer format"""
predictions = convert_from_predictions_to_second_arg(predictions)
categories = {cat["name"]: cat["id"] for cat in dataset["categories"]}
categories = {cat["name"]: cat["id"] for cat in dataset.cats.values()}
real_ids = [id_.split("_")[-1] for id_ in ids]

for id_, img_predictions in zip(real_ids, predictions):
Expand Down
5 changes: 2 additions & 3 deletions src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TypedDict, List
import json
import kwcoco


class DatasetCategory(TypedDict):
Expand Down Expand Up @@ -28,8 +28,7 @@ class Dataset(TypedDict):


def load_dataset(path: str) -> Dataset:
    """Load a COCO dataset from a JSON annotation file at *path*.

    Delegates parsing to ``kwcoco.CocoDataset``, which indexes the COCO
    JSON (images, annotations, categories) and exposes lookup structures
    such as ``.imgs``, ``.anns`` and ``.cats`` used by the callers.

    NOTE(review): the declared return type ``Dataset`` is the legacy
    TypedDict from the pre-kwcoco implementation; the actual runtime
    return value is a ``kwcoco.CocoDataset`` — confirm and update the
    annotation once all callers are migrated.
    """
    return kwcoco.CocoDataset(path)


dataset_json: Dataset = {"categories": [], "images": [], "annotations": []}
Expand Down
7 changes: 3 additions & 4 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library import images_manager
from nrtk_explorer.library.dataset import get_dataset
import nrtk_explorer.test_data

from tabulate import tabulate
from itertools import product
from pathlib import Path

import json
import os
import pytest
import timeit
Expand All @@ -17,9 +17,8 @@


def image_paths_impl(file_name):
with open(file_name) as f:
dataset = json.load(f)
images = dataset["images"]
dataset = get_dataset(file_name)
images = dataset.imgs.values()

paths = list()
for image_metadata in images:
Expand Down

0 comments on commit 58f95b2

Please sign in to comment.