Skip to content

Commit

Permalink
Sub labels hints for yolov8 pose (#53)
Browse files Browse the repository at this point in the history
* more yolo v8 tests

* passing skeleton sub labels to yolo8 extractor

* reducing cognitive complexity of YOLOv8PoseExtractor._load_categories

* relaxing orthogonality restraint to ~1 degree

* allowing sub label hint for yolo8 pose to have more labels than in dataset

* renaming tests

* allowing for non-sequential keys in names dict

* black fixes

* added type annotations

* remove non reachable raise
  • Loading branch information
Eldies authored Aug 5, 2024
1 parent 04832fb commit d3f0551
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 78 deletions.
175 changes: 110 additions & 65 deletions datumaro/plugins/yolo_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
InvalidAnnotationError,
UndeclaredLabelError,
)
from datumaro.components.extractor import DatasetItem, Extractor, SourceExtractor
from datumaro.components.extractor import CategoriesInfo, DatasetItem, Extractor, SourceExtractor
from datumaro.components.media import Image
from datumaro.util.image import (
DEFAULT_IMAGE_META_FILE_NAME,
Expand Down Expand Up @@ -250,6 +250,12 @@ def _parse_annotations(

return annotations

def _map_label_id(self, label_id: str) -> int:
label_id = self._parse_field(label_id, int, "bbox label id")
if label_id not in self._categories[AnnotationType.label]:
raise UndeclaredLabelError(str(label_id))
return label_id

def _load_one_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Annotation:
Expand All @@ -260,9 +266,7 @@ def _load_one_annotation(
)
label_id, xc, yc, w, h = parts

label_id = self._parse_field(label_id, int, "bbox label id")
if label_id not in self._categories[AnnotationType.label]:
raise UndeclaredLabelError(str(label_id))
label_id = self._map_label_id(label_id)

w = self._parse_field(w, float, "bbox width")
h = self._parse_field(h, float, "bbox height")
Expand All @@ -277,7 +281,7 @@ def _load_one_annotation(
label=label_id,
)

def _load_categories(self) -> Dict[AnnotationType, LabelCategories]:
def _load_categories(self) -> CategoriesInfo:
names_path = self._config.get("names")
if not names_path:
raise InvalidAnnotationError(f"Failed to parse names file path from config")
Expand Down Expand Up @@ -334,25 +338,49 @@ def _config(self) -> Dict[str, Any]:
except yaml.YAMLError:
raise InvalidAnnotationError("Failed to parse config file")

def _load_categories(self) -> Dict[AnnotationType, LabelCategories]:
@cached_property
def _label_mapping(self) -> Dict[int, int]:
names = self._config["names"]
if isinstance(names, list):
return {index: index for index in range(len(names))}
if isinstance(names, dict):
return {names_key: index for index, names_key in enumerate(sorted(names.keys()))}
raise InvalidAnnotationError("Failed to parse names from config")

def _map_label_id(self, ann_label_id: str) -> int:
names = self._config["names"]
ann_label_id = self._parse_field(ann_label_id, int, "label id")
if isinstance(names, list):
if ann_label_id < 0 or ann_label_id >= len(names):
raise UndeclaredLabelError(str(ann_label_id))
return ann_label_id

if isinstance(names, dict):
if ann_label_id not in names:
raise UndeclaredLabelError(str(ann_label_id))
return self._label_mapping[ann_label_id]

def _load_names_from_config_file(self) -> list:
names = self._config["names"]
if isinstance(names, dict):
names_with_mapped_keys = {
self._label_mapping[names_key]: names[names_key] for names_key in names
}
return [names_with_mapped_keys[i] for i in range(len(names))]
elif isinstance(names, list):
return names
raise InvalidAnnotationError("Failed to parse names from config")

def _load_categories(self) -> CategoriesInfo:
if has_meta_file(self._path):
return {
AnnotationType.label: LabelCategories.from_iterable(
parse_meta_file(self._path).keys()
)
}

if (names := self._config.get("names")) is not None:
if isinstance(names, dict):
return {
AnnotationType.label: LabelCategories.from_iterable(
[names[i] for i in range(len(names))]
)
}
if isinstance(names, list):
return {AnnotationType.label: LabelCategories.from_iterable(names)}

raise InvalidAnnotationError(f"Failed to parse names from config")
names = self._load_names_from_config_file()
return {AnnotationType.label: LabelCategories.from_iterable(names)}

def _get_labels_path_from_image_path(self, image_path: str) -> str:
relative_image_path = osp.relpath(
Expand Down Expand Up @@ -400,10 +428,7 @@ class YOLOv8SegmentationExtractor(YOLOv8Extractor):
def _load_segmentation_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Polygon:
label_id = self._parse_field(parts[0], int, "Polygon label id")
if label_id not in self._categories[AnnotationType.label]:
raise UndeclaredLabelError(str(label_id))

label_id = self._map_label_id(parts[0])
points = [
self._parse_field(
value, float, f"polygon point {idx // 2} {'x' if idx % 2 == 0 else 'y'}"
Expand All @@ -427,17 +452,28 @@ def _load_one_annotation(


class YOLOv8OrientedBoxesExtractor(YOLOv8Extractor):
@staticmethod
def _check_is_rectangle(p1, p2, p3, p4):
RECTANGLE_ANGLE_PRECISION = math.pi * 1 / 180

@classmethod
def _check_is_rectangle(
cls, p1: Tuple[int, int], p2: Tuple[int, int], p3: Tuple[int, int], p4: Tuple[int, int]
) -> None:
p12_angle = math.atan2(p2[0] - p1[0], p2[1] - p1[1])
p23_angle = math.atan2(p3[0] - p2[0], p3[1] - p2[1])
p43_angle = math.atan2(p3[0] - p4[0], p3[1] - p4[1])
p14_angle = math.atan2(p4[0] - p1[0], p4[1] - p1[1])

if abs(p12_angle - p43_angle) > 0.001 or abs(p23_angle - p14_angle) > 0.001:
if (
abs(p12_angle - p43_angle) > 0.001
or abs(p23_angle - p14_angle) > cls.RECTANGLE_ANGLE_PRECISION
):
raise InvalidAnnotationError(
"Given points do not form a rectangle: opposite sides have different slope angles."
)
if abs((p12_angle - p23_angle) % math.pi - math.pi / 2) > cls.RECTANGLE_ANGLE_PRECISION:
raise InvalidAnnotationError(
"Given points do not form a rectangle: adjacent sides are not orthogonal."
)

def _load_one_annotation(
self, parts: List[str], image_height: int, image_width: int
Expand All @@ -447,9 +483,7 @@ def _load_one_annotation(
f"Unexpected field count {len(parts)} in the bbox description. "
"Expected 9 fields (label, x1, y1, x2, y2, x3, y3, x4, y4)."
)
label_id = self._parse_field(parts[0], int, "bbox label id")
if label_id not in self._categories[AnnotationType.label]:
raise UndeclaredLabelError(str(label_id))
label_id = self._map_label_id(parts[0])
points = [
(
self._parse_field(x, float, f"bbox point {idx} x") * image_width,
Expand Down Expand Up @@ -481,8 +515,17 @@ def _load_one_annotation(


class YOLOv8PoseExtractor(YOLOv8Extractor):
def __init__(
self,
*args,
skeleton_sub_labels: Optional[Dict[str, List[str]]] = None,
**kwargs,
) -> None:
self._skeleton_sub_labels = skeleton_sub_labels
super().__init__(*args, **kwargs)

@cached_property
def _kpt_shape(self):
def _kpt_shape(self) -> list[int]:
if YOLOv8PoseFormat.KPT_SHAPE_FIELD_NAME not in self._config:
raise InvalidAnnotationError(
f"Failed to parse {YOLOv8PoseFormat.KPT_SHAPE_FIELD_NAME} from config"
Expand All @@ -506,61 +549,64 @@ def _kpt_shape(self):
return kpt_shape

@cached_property
def _skeleton_id_to_label_id(self):
def _skeleton_id_to_label_id(self) -> Dict[int, int]:
point_categories = self._categories.get(
AnnotationType.points, PointsCategories.from_iterable([])
)
return {index: label_id for index, label_id in enumerate(sorted(point_categories.items))}

def _load_categories(self) -> Dict[AnnotationType, LabelCategories]:
def _load_categories_from_meta_file(self) -> CategoriesInfo:
dataset_meta = parse_json_file(get_meta_file(self._path))
point_categories = PointsCategories.from_iterable(dataset_meta.get("point_categories", []))
categories = {
AnnotationType.label: LabelCategories.from_iterable(dataset_meta["label_categories"])
}
if len(point_categories) > 0:
categories[AnnotationType.points] = point_categories
return categories

def _load_categories(self) -> CategoriesInfo:
if "names" not in self._config:
raise InvalidAnnotationError(f"Failed to parse names from config")

if has_meta_file(self._path):
dataset_meta = parse_json_file(get_meta_file(self._path))
point_categories = PointsCategories.from_iterable(
dataset_meta.get("point_categories", [])
)
categories = {
AnnotationType.label: LabelCategories.from_iterable(
dataset_meta["label_categories"]
)
}
if len(point_categories) > 0:
categories[AnnotationType.points] = point_categories
return categories
return self._load_categories_from_meta_file()

number_of_points, _ = self._kpt_shape
names = self._config["names"]
if isinstance(names, dict):
if set(names.keys()) != set(range(len(names))):
skeleton_labels = self._load_names_from_config_file()

if self._skeleton_sub_labels:
if missing_labels := set(skeleton_labels) - set(self._skeleton_sub_labels):
raise InvalidAnnotationError(
f"Failed to parse names from config: non-sequential label ids"
f"Labels from config file are absent in sub label hint: {missing_labels}"
)

if skeletons_with_wrong_sub_labels := [
skeleton
for skeleton in skeleton_labels
if len(self._skeleton_sub_labels[skeleton]) != number_of_points
]:
raise InvalidAnnotationError(
f"Number of points in skeletons according to config file is {number_of_points}. "
f"Following skeletons have number of sub labels which differs: {skeletons_with_wrong_sub_labels}"
)
skeleton_labels = [names[i] for i in range(len(names))]
elif isinstance(names, list):
skeleton_labels = names
else:
raise InvalidAnnotationError(f"Failed to parse names from config")

def make_children_names(skeleton_label):
return [
children_labels = self._skeleton_sub_labels or {
skeleton_label: [
f"{skeleton_label}_point_{point_index}" for point_index in range(number_of_points)
]
for skeleton_label in skeleton_labels
}

point_labels = [
(child_name, skeleton_label)
for skeleton_label in skeleton_labels
for child_name in make_children_names(skeleton_label)
for child_name in children_labels[skeleton_label]
]

point_categories = PointsCategories.from_iterable(
[
(
index,
make_children_names(skeleton_label),
set(),
)
(index, children_labels[skeleton_label], set())
for index, skeleton_label in enumerate(skeleton_labels)
]
)
Expand All @@ -572,6 +618,10 @@ def make_children_names(skeleton_label):

return categories

def _map_label_id(self, ann_label_id: str) -> int:
skeleton_id = super()._map_label_id(ann_label_id)
return self._skeleton_id_to_label_id[skeleton_id]

def _load_one_annotation(
self, parts: List[str], image_height: int, image_width: int
) -> Annotation:
Expand All @@ -583,12 +633,7 @@ def _load_one_annotation(
f"and then {values_per_point} for each of {number_of_points} points"
)

skeleton_id = self._parse_field(parts[0], int, "skeleton label id")
label_id = self._skeleton_id_to_label_id.get(skeleton_id, -1)
if label_id not in self._categories[AnnotationType.label]:
raise UndeclaredLabelError(str(skeleton_id))
if self._categories[AnnotationType.label][label_id].parent != "":
raise InvalidAnnotationError("WTF")
label_id = self._map_label_id(parts[0])

point_labels = self._categories[AnnotationType.points][label_id].labels
point_label_ids = [
Expand Down
2 changes: 1 addition & 1 deletion site/content/en/docs/formats/yolo_v8.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ A polygon can have three or more points:
# labels/train/image1.txt:
# <label_index> <x1> <y1> <x2> <y2> <x3> <y3> <x4> <y4> <x5> <y5> ...
0 0.146731 0.151795 0.319936 0.301795 0.186603 0.648205
3 0.557735 0.090192 0.357735 0.609808 0.242265 0.509808 0.442265 -0.009808 0.400000 0.266667
3 0.557735 0.090192 0.357735 0.609808 0.242265 0.509808 0.442265 0.009808 0.400000 0.266667
...
```

Expand Down
Loading

0 comments on commit d3f0551

Please sign in to comment.