Skip to content

Commit

Permalink
feat: add filtering functions
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Oct 21, 2024
1 parent 38d6f1d commit 530b288
Show file tree
Hide file tree
Showing 17 changed files with 639 additions and 11 deletions.
2 changes: 0 additions & 2 deletions docs/apis/common.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@
::: t4_devkit.common.io

::: t4_devkit.common.timestamp

::: t4_devkit.common.transform
<!-- prettier-ignore-end -->
2 changes: 2 additions & 0 deletions docs/apis/dataclass.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@

::: t4_devkit.dataclass.trajectory

::: t4_devkit.dataclass.transform

<!-- prettier-ignore-end -->
12 changes: 12 additions & 0 deletions docs/apis/filtering.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# `filtering`

<!-- prettier-ignore-start -->
::: t4_devkit.filtering.compose

::: t4_devkit.filtering.functional
options:
filters: ["!BaseBoxFilter"]
show_bases: false

::: t4_devkit.filtering.parameter
<!-- prettier-ignore-end -->
1 change: 1 addition & 0 deletions mkdocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ nav:
- TIER IV: apis/tier4.md
- Schema: apis/schema.md
- DataClass: apis/dataclass.md
- Filtering: apis/filtering.md
- Viewer: apis/viewer.md
- Common: apis/common.md

Expand Down
1 change: 1 addition & 0 deletions t4_devkit/dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .roi import * # noqa
from .shape import * # noqa
from .trajectory import * # noqa
from .transform import * # noqa
34 changes: 33 additions & 1 deletion t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .trajectory import to_trajectories

if TYPE_CHECKING:
from t4_devkit.dataclass import HomogeneousMatrix
from t4_devkit.typing import (
NDArrayF64,
RotationType,
Expand All @@ -26,7 +27,34 @@
from .trajectory import Trajectory


__all__ = ["Box3D", "Box2D", "BoxType"]
__all__ = ["Box3D", "Box2D", "BoxType", "distance_box"]


def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None:
"""Return a box distance from `base_link`.
Args:
box (BoxType): A box.
tf_matrix (HomogeneousMatrix): Transformation matrix.
Raises:
TypeError: Expecting type of box is `Box2D` or `Box3D`.
Returns:
float | None: Return `None` if the type of box is `Box2D` and its `position` is `None`,
otherwise returns distance from `base_link`.
"""
if isinstance(box, Box2D) and box.position is None:
return None

if isinstance(box, Box2D):
position = tf_matrix.transform(box.position)
elif isinstance(box, Box3D):
position, _ = tf_matrix.transform(box.position, box.rotation)
else:
raise TypeError(f"Unexpected box type: {type(box)}")

return np.linalg.norm(position)


@dataclass(eq=False)
Expand All @@ -40,6 +68,10 @@ class BaseBox:
uuid: str | None = field(default=None, kw_only=True)


# TODO: add intermediate class to represent the box state.
# >>> e.g.) box.as_state() -> BoxState


@dataclass(eq=False)
class Box3D(BaseBox):
"""A class to represent 3D box.
Expand Down
6 changes: 3 additions & 3 deletions t4_devkit/dataclass/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def from_name(cls, name: str) -> Self:
assert name in cls.__members__, f"Unexpected label name: {name}"
return cls.__members__[name]

def __eq__(self, other: LabelID | str) -> bool:
def __eq__(self, other: str | LabelID) -> bool:
return self.name == other.upper() if isinstance(other, str) else self.name == other.name


Expand All @@ -72,8 +72,8 @@ class SemanticLabel:
original: str | None = field(default=None)
attributes: list[str] = field(default_factory=list)

def __eq__(self, other: SemanticLabel) -> bool:
return self.label == other.label
def __eq__(self, other: str | SemanticLabel) -> bool:
return self.label == other if isinstance(other, str) else self.label == other.label


# =====================
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, overload

import numpy as np
from pyquaternion import Quaternion
from typing_extensions import Self

from t4_devkit.typing import NDArray, RotationType

if TYPE_CHECKING:
from t4_devkit.typing import ArrayLike, NDArray, RotationType
from t4_devkit.typing import ArrayLike

__all__ = ["TransformBuffer", "HomogeneousMatrix", "TransformLike"]


@dataclass
class TransformBuffer:
buffer: dict[tuple[str, str], HomogeneousMatrix] = field(default_factory=dict, init=False)

def set_transform(self, matrix: HomogeneousMatrix) -> None:
"""Set transform matrix to the buffer.
Also, if its inverse transformation has not been registered, registers it too.
Args:
matrix (HomogeneousMatrix): Transformation matrix.
"""
src = matrix.src
dst = matrix.dst
if (src, dst) not in self.buffer:
self.buffer[(src, dst)] = matrix

if (dst, src) not in self.buffer:
self.buffer[(dst, src)] = matrix.inv()

def lookup_transform(self, src: str, dst: str) -> HomogeneousMatrix | None:
if src == dst:
return HomogeneousMatrix(np.zeros(3), Quaternion(), src=src, dst=dst)
return self.buffer[(src, dst)] if (src, dst) in self.buffer else None

def do_transform(self, src: str, dst: str, *args: TransformLike) -> TransformLike | None:
return self.buffer[(src, dst)].transform(args) if (src, dst) in self.buffer else None


@dataclass
class HomogeneousMatrix:
def __init__(
self,
Expand Down Expand Up @@ -214,6 +248,9 @@ def __transform_matrix(self, matrix: HomogeneousMatrix) -> HomogeneousMatrix:
return matrix.dot(self)


TransformLike = NDArray | tuple[NDArray, RotationType] | HomogeneousMatrix


def _extract_position_and_rotation_from_matrix(
matrix: NDArray | HomogeneousMatrix,
) -> tuple[NDArray, Quaternion]:
Expand Down Expand Up @@ -265,3 +302,4 @@ def _generate_homogeneous_matrix(
matrix[:3, 3] = position
matrix[:3, :3] = rotation.rotation_matrix
return matrix
return matrix
4 changes: 4 additions & 0 deletions t4_devkit/filtering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .compose import BoxFilter
from .parameter import FilterParams

__all__ = ["BoxFilter", "FilterParams"]
53 changes: 53 additions & 0 deletions t4_devkit/filtering/compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

from .functional import (
FilterByDistance,
FilterByLabel,
FilterByNumPoints,
FilterByPosition,
FilterBySpeed,
FilterByUUID,
)
from .parameter import FilterParams

if TYPE_CHECKING:
from t4_devkit.dataclass import BoxType, TransformBuffer

from .functional import FilterLike


class BoxFilter:
"""A class composes multiple filtering functions."""

def __init__(self, params: FilterParams, tf_buffer: TransformBuffer) -> None:
"""Construct a new object.
Args:
params (FilterParams): Filtering parameters.
tf_buffer (TransformBuffer): Transformation buffer.
"""
self.filters: list[FilterLike] = [
FilterByLabel.from_params(params),
FilterByUUID.from_params(params),
FilterByDistance.from_params(params),
FilterByPosition.from_params(params),
FilterBySpeed.from_params(params),
FilterByNumPoints.from_params(params),
]

self.tf_buffer = tf_buffer

def __call__(self, boxes: Sequence[BoxType]) -> list[BoxType]:
output: list[BoxType] = []

for box in boxes:
tf_matrix = self.tf_buffer.lookup_transform(box.frame_id, "base_link")

is_ok = all(func(box, tf_matrix) for func in self.filters)

if is_ok:
output.append(box)

return output
Loading

0 comments on commit 530b288

Please sign in to comment.