-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #101 from Kitware/main
Bring master changes
- Loading branch information
Showing
10 changed files
with
142 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,3 +66,25 @@ jobs: | |
pip install -e '.[dev]' | ||
- name: Invoke PyTest | ||
run: pytest -v . | ||
|
||
semantic_release: | ||
runs-on: ubuntu-latest | ||
name: Semantic release noop | ||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
fetch-depth: 0 | ||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip setuptools wheel "build<0.10.0" python-semantic-release | ||
- name: Python Semantic Release | ||
id: release | ||
uses: python-semantic-release/[email protected] | ||
with: | ||
root_options: -vv --noop |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
name: nrtk-explorer-cuda | ||
channels: | ||
- nvidia | ||
- pytorch | ||
- conda-forge | ||
dependencies: | ||
- pytorch::pytorch | ||
- pytorch::torchvision | ||
- pytorch::torchaudio | ||
- pytorch::pytorch-cuda=11.8 | ||
- nrtk-explorer | ||
variables: | ||
channel_priority: flexible |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,59 @@ | ||
import kwcoco | ||
""" | ||
Module to load the dataset and get the image file path given an image id. | ||
Example: | ||
dataset = get_dataset("path/to/dataset.json") | ||
image_fpath = dataset.get_image_fpath(image_id) | ||
""" | ||
|
||
from functools import lru_cache | ||
from pathlib import Path | ||
|
||
import json | ||
|
||
|
||
class DefaultDataset: | ||
"""Default dataset class to load the dataset and get the image file path given an image id.""" | ||
|
||
def __init__(self, path: str): | ||
with open(path) as f: | ||
self.data = json.load(f) | ||
self.fpath = path | ||
self.cats = {cat["id"]: cat for cat in self.data["categories"]} | ||
self.anns = {ann["id"]: ann for ann in self.data["annotations"]} | ||
self.imgs = {img["id"]: img for img in self.data["images"]} | ||
|
||
def get_image_fpath(self, selected_id: int): | ||
"""Get the image file path given an image id.""" | ||
dataset_dir = Path(self.fpath).parent | ||
file_name = self.imgs[selected_id]["file_name"] | ||
return str(dataset_dir / file_name) | ||
|
||
def load_dataset(path: str): | ||
return kwcoco.CocoDataset(path) | ||
|
||
@lru_cache | ||
def __load_dataset(path: str): | ||
"""Load the dataset given the path to the dataset file.""" | ||
try: | ||
import kwcoco | ||
|
||
dataset: kwcoco.CocoDataset = kwcoco.CocoDataset() | ||
dataset_path: str = "" | ||
return kwcoco.CocoDataset(path) | ||
except ImportError: | ||
return DefaultDataset(path) | ||
|
||
|
||
def get_dataset(path: str, force_reload=False): | ||
global dataset, dataset_path | ||
if dataset_path != path or force_reload: | ||
dataset_path = path | ||
dataset = load_dataset(dataset_path) | ||
return dataset | ||
def get_dataset(path: str, force_reload: bool = False): | ||
"""Get the dataset object given the path to the dataset file. | ||
Args: | ||
path (str): Path to the dataset file. | ||
force_reload (bool): Whether to force reload the dataset. Default: False. | ||
Return: | ||
dataset: Dataset object. | ||
""" | ||
if force_reload: | ||
__load_dataset.cache_clear() | ||
return __load_dataset(path) | ||
|
||
|
||
def get_image_path(id: str): | ||
dataset_dir = Path(dataset_path).parent | ||
file_name = dataset.imgs[int(id)]["file_name"] | ||
return str(dataset_dir / file_name) | ||
def get_image_fpath(selected_id: int, path: str): | ||
"""Get the image file path given an image id.""" | ||
return get_dataset(path).get_image_fpath(selected_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from nrtk_explorer.library.dataset import get_dataset, DefaultDataset | ||
import nrtk_explorer.test_data | ||
|
||
from unittest import mock | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def dataset_path(): | ||
dir_name = Path(nrtk_explorer.test_data.__file__).parent | ||
return f"{dir_name}/coco-od-2017/test_val2017.json" | ||
|
||
|
||
def test_get_dataset(dataset_path): | ||
ds1 = get_dataset(dataset_path) | ||
assert ds1 is not None | ||
|
||
ds1 = get_dataset(dataset_path) | ||
ds2 = get_dataset(dataset_path) | ||
assert ds1 is ds2 | ||
|
||
ds1 = get_dataset(dataset_path) | ||
ds2 = get_dataset(dataset_path, force_reload=True) | ||
assert ds1 is not ds2 | ||
|
||
|
||
@mock.patch("nrtk_explorer.library.dataset.__load_dataset", lambda path: DefaultDataset(path)) | ||
def test_get_dataset_empty(): | ||
with pytest.raises(FileNotFoundError): | ||
get_dataset("nonexisting") | ||
|
||
|
||
def test_DefaultDataset(dataset_path): | ||
ds = DefaultDataset(dataset_path) | ||
assert len(ds.imgs) > 0 | ||
assert len(ds.cats) > 0 | ||
assert len(ds.anns) > 0 | ||
assert Path(ds.get_image_fpath(491497)).name == "000000491497.jpg" |