-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #85 from nisyad-ms/nisyad/ic_od_to_kvp_adapter
Add IC OD to KVP Format Converter
- Loading branch information
Showing 9 changed files with 358 additions and 4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
tests/test_ic_od_to_kvp_wrapper/test_classification_as_kvp.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import unittest | ||
|
||
from tests.test_fixtures import MultilcassClassificationTestFixtures | ||
from vision_datasets.common import DatasetTypes | ||
from vision_datasets.image_classification import ClassificationAsKeyValuePairDataset | ||
from vision_datasets.key_value_pair.manifest import KeyValuePairLabelManifest | ||
|
||
|
||
class TestClassificationAsKeyValuePairDataset(unittest.TestCase):
    """Tests for exposing a multiclass classification dataset as a key-value-pair dataset."""

    def test_multiclass_classification(self):
        """Wrap the sample multiclass dataset and check dataset info, schema and the first label."""
        sample_classification_dataset, _ = MultilcassClassificationTestFixtures.create_an_ic_dataset()
        kvp_dataset = ClassificationAsKeyValuePairDataset(sample_classification_dataset)

        self.assertIsInstance(kvp_dataset, ClassificationAsKeyValuePairDataset)
        self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR)
        self.assertIn("name", kvp_dataset.dataset_info.schema)
        self.assertIn("description", kvp_dataset.dataset_info.schema)
        self.assertIn("fieldSchema", kvp_dataset.dataset_info.schema)

        # Every class of the wrapped dataset must appear in the schema with its per-class description.
        self.assertEqual(
            kvp_dataset.dataset_info.schema["fieldSchema"],
            {
                "className": {
                    "type": "string",
                    "description": "Class name that the image belongs to.",
                    "classes": {
                        "1-class": {"description": "A single class name. Only output 1-class as the class name if present."},
                        "2-class": {"description": "A single class name. Only output 2-class as the class name if present."},
                        "3-class": {"description": "A single class name. Only output 3-class as the class name if present."},
                    },
                },
            },
        )

        # Items are (image, target, extra) triples; only the converted target is checked here.
        _, target, _ = kvp_dataset[0]
        self.assertIsInstance(target, KeyValuePairLabelManifest)
        self.assertEqual(
            target.label_data,
            {"fields": {"className": {"value": "1-class"}}},
        )
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import unittest | ||
|
||
from tests.test_fixtures import DetectionTestFixtures | ||
from vision_datasets.common.constants import DatasetTypes | ||
from vision_datasets.image_object_detection import DetectionAsKeyValuePairDataset | ||
from vision_datasets.key_value_pair.manifest import KeyValuePairLabelManifest | ||
|
||
|
||
class TestDetectionAsKeyValuePairDataset(unittest.TestCase):
    """Tests for exposing an object detection dataset as a key-value-pair dataset."""

    def test_detection_to_kvp(self):
        """Wrap the sample detection dataset and check dataset info, schema and the first label."""
        sample_detection_dataset, _ = DetectionTestFixtures.create_an_od_dataset()
        kvp_dataset = DetectionAsKeyValuePairDataset(sample_detection_dataset)

        self.assertIsInstance(kvp_dataset, DetectionAsKeyValuePairDataset)
        self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR)
        for top_level_key in ("name", "description", "fieldSchema"):
            self.assertIn(top_level_key, kvp_dataset.dataset_info.schema)

        expected_field_schema = {
            'detectedObjects': {
                'type': 'array',
                'description': 'Objects in the image of the specified classes, with bounding boxes',
                'items': {
                    'type': 'string',
                    'description': 'Class name of the object',
                    'classes': {
                        '1-class': {},
                        '2-class': {},
                        '3-class': {},
                        '4-class': {},
                    },
                    'includeGrounding': True,
                },
            },
        }
        self.assertEqual(kvp_dataset.dataset_info.schema["fieldSchema"], expected_field_schema)

        # Items are (image, target, extra) triples; only the converted target is checked here.
        _, target, _ = kvp_dataset[0]
        self.assertIsInstance(target, KeyValuePairLabelManifest)
        expected_label_data = {
            'fields': {
                'detectedObjects': {
                    'value': [
                        {'value': '1-class', 'groundings': [[0, 0, 100, 100]]},
                        {'value': '2-class', 'groundings': [[10, 10, 50, 100]]},
                    ],
                },
            },
        }
        self.assertEqual(target.label_data, expected_label_data)

    def test_single_class_description(self):
        """With a single category, the schema carries the 'always output' class description."""
        sample_detection_dataset, _ = DetectionTestFixtures.create_an_od_dataset(n_categories=1)
        kvp_dataset = DetectionAsKeyValuePairDataset(sample_detection_dataset)

        expected_classes = {'1-class': {"description": "Always output 1-class as the class."}}
        self.assertEqual(
            kvp_dataset.dataset_info.schema["fieldSchema"]['detectedObjects']['items']['classes'],
            expected_classes,
        )
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from .coco_manifest_adaptor import MultiClassClassificationCocoManifestAdaptor, MultiLabelClassificationCocoManifestAdaptor | ||
from .operations import ImageClassificationCocoDictGenerator | ||
from .manifest import ImageClassificationLabelManifest | ||
from .classification_as_kvp_dataset import ClassificationAsKeyValuePairDataset | ||
|
||
# Public API of the image_classification package, including the key-value-pair wrapper.
__all__ = ['MultiClassClassificationCocoManifestAdaptor', 'MultiLabelClassificationCocoManifestAdaptor',
           'ImageClassificationCocoDictGenerator',
           'ImageClassificationLabelManifest',
           'ClassificationAsKeyValuePairDataset']
89 changes: 89 additions & 0 deletions
89
vision_datasets/image_classification/classification_as_kvp_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import logging | ||
import typing | ||
from copy import deepcopy | ||
|
||
from vision_datasets.common import DatasetTypes, KeyValuePairDatasetInfo, VisionDataset | ||
from vision_datasets.key_value_pair import ( | ||
KeyValuePairDatasetManifest, | ||
KeyValuePairLabelManifest, | ||
) | ||
|
||
logger = logging.getLogger(__name__)


# Field name under which the class label is stored in the schema and in each label.
CLASS_NAME_KEY = "className"

# Schema template for multiclass classification; "classes" starts empty and is
# populated per dataset by ClassificationAsKeyValuePairDataset.construct_schema.
BASE_CLASSIFICATION_SCHEMA = {
    "name": "Multiclass image classification",
    "description": "Classify images into one of the provided classes.",
    "fieldSchema": {
        CLASS_NAME_KEY: dict(
            type="string",
            description="Class name that the image belongs to.",
            classes={},
        ),
    },
}
|
||
|
||
class ClassificationAsKeyValuePairDataset(VisionDataset):
    """Dataset class that accesses a multiclass classification dataset as a KeyValuePair dataset.

    The wrapped dataset is deep-copied, its dataset info is re-typed as key-value-pair with a
    schema listing every class name, and each image's first label is turned into one
    key-value-pair annotation.
    """

    def __init__(self, classification_dataset: VisionDataset):
        """
        Initializes an instance of the ClassificationAsKeyValuePairDataset class.

        Args:
            classification_dataset (VisionDataset): The multiclass classification dataset to convert to key-value pair dataset.

        Raises:
            ValueError: If classification_dataset is None or is not a multiclass image classification dataset.
        """

        # TODO: Add support for multilabel classification
        supported_types = {DatasetTypes.IMAGE_CLASSIFICATION_MULTICLASS}
        if classification_dataset is None:
            raise ValueError("classification_dataset must not be None.")
        dataset_type = classification_dataset.dataset_info.type
        if dataset_type not in supported_types:
            raise ValueError(
                f"Unsupported dataset type {dataset_type!r}; only multiclass image classification datasets can be converted."
            )

        # Work on a copy: dataset info and image labels are modified below.
        classification_dataset = deepcopy(classification_dataset)
        source_manifest = classification_dataset.dataset_manifest

        # Generate schema and update dataset info
        dataset_info_dict = classification_dataset.dataset_info.__dict__
        dataset_info_dict["type"] = DatasetTypes.KEY_VALUE_PAIR.name.lower()
        self.class_names = [c.name for c in source_manifest.categories]
        self.class_id_to_names = {c.id: c.name for c in source_manifest.categories}
        self.img_id_to_pos = {x.id: i for i, x in enumerate(source_manifest.images)}

        schema = self.construct_schema(self.class_names)
        # Update dataset_info with schema
        dataset_info = KeyValuePairDatasetInfo({**dataset_info_dict, "schema": schema})

        # Construct KeyValuePairDatasetManifest; annotation ids are 1-based
        annotations = []
        for annotation_id, img in enumerate(source_manifest.images, 1):
            # Only the first label of each image is used (multiclass input).
            label_id = img.labels[0].label_data
            label_name = self.class_id_to_names[label_id]

            kvp_label_data = self.construct_kvp_label_data(label_name)
            img_ids = [self.img_id_to_pos[img.id]]  # 0-based index
            kvp_annotation = KeyValuePairLabelManifest(annotation_id, img_ids, label_data=kvp_label_data)

            # KVPDatasetManifest expects img.labels to be empty. Labels are instead stored in KVP annotation
            img.labels = []
            annotations.append(kvp_annotation)

        dataset_manifest = KeyValuePairDatasetManifest(
            source_manifest.images,
            annotations,
            schema,
            additional_info=source_manifest.additional_info,
        )
        super().__init__(dataset_info, dataset_manifest, dataset_resources=classification_dataset.dataset_resources)

    def construct_schema(self, class_names: typing.List[str]) -> typing.Dict[str, typing.Any]:
        """
        Build the key-value-pair schema for the given class names.

        Args:
            class_names (List[str]): Names of the classes to list under the class-name field.

        Returns:
            A new schema dict based on BASE_CLASSIFICATION_SCHEMA with "classes" filled in.
        """
        # Copy the template so the module-level base schema is never mutated and
        # schemas of different dataset instances do not share state.
        schema: typing.Dict[str, typing.Any] = deepcopy(BASE_CLASSIFICATION_SCHEMA)
        schema["fieldSchema"][CLASS_NAME_KEY]["classes"] = {
            c: {"description": f"A single class name. Only output {c} as the class name if present."}
            for c in class_names
        }
        return schema

    def construct_kvp_label_data(self, label_name: str) -> typing.Dict[str, typing.Union[typing.Dict[str, typing.Dict[str, str]], None]]:
        """
        Convert the classification dataset label_name to the desired format for KVP annotation as defined by the BASE_CLASSIFICATION_SCHEMA.
        E.g. {"fields": {"className": {"value": <label_name>}}}
        """
        return {
            f"{KeyValuePairLabelManifest.LABEL_KEY}": {
                CLASS_NAME_KEY: {
                    f"{KeyValuePairLabelManifest.LABEL_VALUE_KEY}": label_name
                }
            }
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.