Skip to content

Commit

Permalink
feat: add support of rendering segmentation image (#22)
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 authored Oct 23, 2024
1 parent 1d8da91 commit e6dc14d
Show file tree
Hide file tree
Showing 13 changed files with 444 additions and 53 deletions.
46 changes: 44 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ rerun-sdk = "0.17.0"
pyquaternion = "^0.9.9"
matplotlib = "^3.9.2"
shapely = "<2.0.0"
pycocotools = "^2.0.8"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
Expand Down
57 changes: 51 additions & 6 deletions t4_devkit/schema/tables/object_ann.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,51 @@
from __future__ import annotations

import base64
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from pycocotools import mask as cocomask
from typing_extensions import Self

from ..name import SchemaName
from .base import SchemaBase
from .registry import SCHEMAS
from ..name import SchemaName

if TYPE_CHECKING:
from t4_devkit.typing import MaskType, RoiType
from t4_devkit.typing import NDArrayU8, RoiType

__all__ = ("ObjectAnn",)
__all__ = ("ObjectAnn", "RLEMask")


@dataclass
class RLEMask:
"""A dataclass to represent segmentation mask compressed by RLE.
Attributes:
size (list[int, int]): Size of image ordering (width, height).
counts (str): RLE compressed mask data.
"""

size: list[int, int]
counts: str

@property
def width(self) -> int:
return self.size[0]

@property
def height(self) -> int:
return self.size[1]

def decode(self) -> NDArrayU8:
"""Decode segmentation mask.
Returns:
Decoded mask in shape of (H, W).
"""
counts = base64.b64decode(self.counts)
data = {"counts": counts, "size": self.size}
return cocomask.decode(data).T


@dataclass
Expand All @@ -27,7 +60,7 @@ class ObjectAnn(SchemaBase):
category_token (str): Foreign key pointing to the object category.
attribute_tokens (list[str]): Foreign keys. List of attributes for this annotation.
bbox (RoiType): Annotated bounding box. Given as [xmin, ymin, xmax, ymax].
mask (MaskType): Instance mask using the COCO format.
mask (RLEMask): Instance mask using the COCO format compressed by RLE.
"""

token: str
Expand All @@ -36,7 +69,7 @@ class ObjectAnn(SchemaBase):
category_token: str
attribute_tokens: list[str]
bbox: RoiType
mask: MaskType
mask: RLEMask

# shortcuts
category_name: str = field(init=False)
Expand All @@ -47,12 +80,24 @@ def shortcuts() -> tuple[str]:

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
return cls(**data)
new_data = data.copy()
new_data["mask"] = RLEMask(**data["mask"])
return cls(**new_data)

@property
def width(self) -> int:
"""Return the width of the bounding box.
Returns:
Bounding box width in pixel.
"""
return self.bbox[2] - self.bbox[0]

@property
def height(self) -> int:
"""Return the height of the bounding box.
Returns:
Bounding box height in pixel.
"""
return self.bbox[3] - self.bbox[1]
11 changes: 7 additions & 4 deletions t4_devkit/schema/tables/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from typing_extensions import Self

from ..name import SchemaName
from .base import SchemaBase
from .registry import SCHEMAS
from ..name import SchemaName

__all__ = ("Sample",)

Expand All @@ -30,7 +30,9 @@ class Sample(SchemaBase):
This should be set after instantiated.
ann_3ds (list[str]): List of foreign keys pointing the sample annotations.
This should be set after instantiated.
ann_3ds (list[str]): List of foreign keys pointing the object annotations.
ann_2ds (list[str]): List of foreign keys pointing the object annotations.
This should be set after instantiated.
surface_anns (list[str]): List of foreign keys pointing the surface annotations.
This should be set after instantiated.
"""

Expand All @@ -44,10 +46,11 @@ class Sample(SchemaBase):
data: dict[str, str] = field(default_factory=dict, init=False)
ann_3ds: list[str] = field(default_factory=list, init=False)
ann_2ds: list[str] = field(default_factory=list, init=False)
surface_anns: list[str] = field(default_factory=list, init=False)

@staticmethod
def shortcuts() -> tuple[str, str, str]:
return ("data", "ann_3ds", "ann_2ds")
def shortcuts() -> tuple[str, str, str, str]:
return ("data", "ann_3ds", "ann_2ds", "surface_ann_2ds")

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
Expand Down
36 changes: 30 additions & 6 deletions t4_devkit/schema/tables/surface_ann.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np
from typing_extensions import Self

from ..name import SchemaName
from .base import SchemaBase
from .object_ann import RLEMask
from .registry import SCHEMAS
from ..name import SchemaName

if TYPE_CHECKING:
from t4_devkit.typing import MaskType
from t4_devkit.typing import RoiType

__all__ = ("SurfaceAnn",)

Expand All @@ -24,14 +26,36 @@ class SurfaceAnn(SchemaBase):
token (str): Unique record identifier.
sample_data_token (str): Foreign key pointing to the sample data, which must be a keyframe image.
category_token (str): Foreign key pointing to the surface category.
mask (MaskType): Segmentation mask using the COCO format.
mask (RLEMask): Segmentation mask using the COCO format compressed by RLE.
"""

token: str
sample_data_token: str
category_token: str
mask: MaskType
mask: RLEMask

# shortcuts
category_name: str = field(init=False)

@staticmethod
def shortcuts() -> tuple[str]:
return ("category_name",)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
return cls(**data)
new_data = data.copy()
new_data["mask"] = RLEMask(**data["mask"])
return cls(**new_data)

@property
def bbox(self) -> RoiType:
"""Return a bounding box corners calculated from polygon vertices.
Returns:
Given as [xmin, ymin, xmax, ymax].
"""
mask = self.mask.decode()
indices = np.where(mask == 1)
xmin, ymin = np.min(indices, axis=1)
xmax, ymax = np.max(indices, axis=1)
return xmin, ymin, xmax, ymax
93 changes: 83 additions & 10 deletions t4_devkit/tier4.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def __make_reverse_index__(self, verbose: bool) -> None:
category: Category = self.get("category", instance.category_token)
record.category_name = category.name

for record in self.surface_ann:
category: Category = self.get("category", record.category_token)
record.category_name = category.name

registered_channels: list[str] = []
for record in self.sample_data:
cs_record: CalibratedSensor = self.get(
Expand Down Expand Up @@ -216,6 +220,11 @@ def __make_reverse_index__(self, verbose: bool) -> None:
sample_record: Sample = self.get("sample", sd_record.sample_token)
sample_record.ann_2ds.append(ann_record.token)

for ann_record in self.surface_ann:
sd_record: SampleData = self.get("sample_data", ann_record.sample_data_token)
sample_record: Sample = self.get("sample", sd_record.sample_token)
sample_record.surface_anns.append(ann_record.token)

log_to_map: dict[str, str] = {}
for map_record in self.map:
for log_token in map_record.log_tokens:
Expand Down Expand Up @@ -1127,18 +1136,48 @@ def _render_annotation_2ds(
if max_timestamp_us < sample.timestamp:
break

if instance_token is not None:
boxes = []
for ann_token in sample.ann_2ds:
ann: ObjectAnn = self.get("object_ann", ann_token)
if ann.instance_token == instance_token:
boxes.append(self.get_box2d(ann_token))
break
else:
boxes = list(map(self.get_box2d, sample.ann_2ds))
boxes: list[Box2D] = []

# For segmentation masks
# TODO: declare specific class for segmentation mask in `dataclass`
camera_masks: dict[str, dict[str, list]] = {}

# Object Annotation
for ann_token in sample.ann_2ds:
ann: ObjectAnn = self.get("object_ann", ann_token)
box = self.get_box2d(ann_token)
boxes.append(box)

sample_data: SampleData = self.get("sample_data", ann.sample_data_token)
camera_masks = _append_mask(
camera_masks,
camera=sample_data.channel,
ann=ann,
class_id=self._label2id[ann.category_name],
uuid=box.uuid,
)

# Render 2D box
viewer.render_box2ds(us2sec(sample.timestamp), boxes)

# TODO: add support of rendering object/surface mask and keypoints
# Surface Annotation
for ann_token in sample.surface_anns:
sample_data: SampleData = self.get("sample_data", ann.sample_data_token)
ann: SurfaceAnn = self.get("surface_ann", ann_token)
camera_masks = _append_mask(
camera_masks,
camera=sample_data.channel,
ann=ann,
class_id=self._label2id[ann.category_name],
)

# Render 2D segmentation image
for camera, data in camera_masks.items():
viewer.render_segmentation2d(
seconds=us2sec(sample.timestamp), camera=camera, **data
)

# TODO: add support of rendering keypoints
current_sample_token = sample.next

def _render_sensor_calibration(self, viewer: Tier4Viewer, sample_data_token: str) -> None:
Expand All @@ -1154,3 +1193,37 @@ def _render_sensor_calibration(self, viewer: Tier4Viewer, sample_data_token: str
)
sensor: Sensor = self.get("sensor", calibrated_sensor.sensor_token)
viewer.render_calibration(sensor, calibrated_sensor)


def _append_mask(
camera_masks: dict[str, dict[str, list]],
camera: str,
ann: ObjectAnn | SurfaceAnn,
class_id: int,
uuid: str | None = None,
) -> dict[str, dict[str, list]]:
"""Append segmentation mask data from `ObjectAnn/SurfaceAnn`.
TODO:
This function should be removed after declaring specific dataclass for 2D segmentation.
Args:
camera_masks (dict[str, dict[str, list]]): Key-value data mapping camera name and mask data.
camera (str): Name of camera channel.
ann (ObjectAnn | SurfaceAnn): Annotation object.
class_id (int): Class ID.
uuid (str | None, optional): Unique instance identifier.
Returns:
dict[str, dict[str, list]]: Updated `camera_masks`.
"""
if camera in camera_masks:
camera_masks[camera]["masks"].append(ann.mask.decode())
camera_masks[camera]["class_ids"].append(class_id)
camera_masks[camera]["uuids"].append(uuid)
else:
camera_masks[camera] = {}
camera_masks[camera]["masks"] = [ann.mask.decode()]
camera_masks[camera]["class_ids"] = [class_id]
camera_masks[camera]["uuids"] = [class_id]
return camera_masks
2 changes: 0 additions & 2 deletions t4_devkit/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"CamIntrinsicType",
"CamDistortionType",
"RoiType",
"MaskType",
"KeypointType",
)

Expand Down Expand Up @@ -54,5 +53,4 @@

# 2D
RoiType = NewType("RoiType", tuple[int, int, int, int]) # (xmin, ymin, xmax, ymax)
MaskType = NewType("MaskType", list[int])
KeypointType = NewType("KeypointType", NDArrayF64)
Loading

0 comments on commit e6dc14d

Please sign in to comment.