diff --git a/docs/apis/common.md b/docs/apis/common.md index 1d946d7..aae2712 100644 --- a/docs/apis/common.md +++ b/docs/apis/common.md @@ -6,6 +6,4 @@ ::: t4_devkit.common.io ::: t4_devkit.common.timestamp - -::: t4_devkit.common.transform diff --git a/docs/apis/dataclass.md b/docs/apis/dataclass.md index b01e8c9..521ff99 100644 --- a/docs/apis/dataclass.md +++ b/docs/apis/dataclass.md @@ -20,4 +20,6 @@ ::: t4_devkit.dataclass.trajectory +::: t4_devkit.dataclass.transform + diff --git a/docs/apis/filtering.md b/docs/apis/filtering.md new file mode 100644 index 0000000..d6d401d --- /dev/null +++ b/docs/apis/filtering.md @@ -0,0 +1,12 @@ +# `filtering` + + +::: t4_devkit.filtering.compose + +::: t4_devkit.filtering.functional + options: + filters: ["!BaseBoxFilter"] + show_bases: false + +::: t4_devkit.filtering.parameter + diff --git a/mkdocs.yaml b/mkdocs.yaml index 180c966..f85b96f 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -10,6 +10,7 @@ nav: - TIER IV: apis/tier4.md - Schema: apis/schema.md - DataClass: apis/dataclass.md + - Filtering: apis/filtering.md - Viewer: apis/viewer.md - Common: apis/common.md diff --git a/t4_devkit/dataclass/__init__.py b/t4_devkit/dataclass/__init__.py index f76005a..750780d 100644 --- a/t4_devkit/dataclass/__init__.py +++ b/t4_devkit/dataclass/__init__.py @@ -4,3 +4,4 @@ from .roi import * # noqa from .shape import * # noqa from .trajectory import * # noqa +from .transform import * # noqa diff --git a/t4_devkit/dataclass/box.py b/t4_devkit/dataclass/box.py index 101c29b..f7b6729 100644 --- a/t4_devkit/dataclass/box.py +++ b/t4_devkit/dataclass/box.py @@ -12,6 +12,7 @@ from .trajectory import to_trajectories if TYPE_CHECKING: + from t4_devkit.dataclass import HomogeneousMatrix from t4_devkit.typing import ( NDArrayF64, RotationType, @@ -26,7 +27,34 @@ from .trajectory import Trajectory -__all__ = ["Box3D", "Box2D", "BoxType"] +__all__ = ["Box3D", "Box2D", "BoxType", "distance_box"] + + +def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None: + """Return a box distance from `base_link`. + + Args: + box (BoxType): A box. + tf_matrix (HomogeneousMatrix): Transformation matrix. + + Raises: + TypeError: Expecting type of box is `Box2D` or `Box3D`. + + Returns: + float | None: Return `None` if the type of box is `Box2D` and its `position` is `None`, + otherwise returns distance from `base_link`. + """ + if isinstance(box, Box2D) and box.position is None: + return None + + if isinstance(box, Box2D): + position = tf_matrix.transform(box.position) + elif isinstance(box, Box3D): + position, _ = tf_matrix.transform(box.position, box.rotation) + else: + raise TypeError(f"Unexpected box type: {type(box)}") + + return np.linalg.norm(position) @dataclass(eq=False) @@ -40,6 +68,10 @@ class BaseBox: uuid: str | None = field(default=None, kw_only=True) +# TODO: add intermediate class to represent the box state. +# >>> e.g.) box.as_state() -> BoxState + + @dataclass(eq=False) class Box3D(BaseBox): """A class to represent 3D box. diff --git a/t4_devkit/dataclass/label.py b/t4_devkit/dataclass/label.py index 87a076c..ab5c2e2 100644 --- a/t4_devkit/dataclass/label.py +++ b/t4_devkit/dataclass/label.py @@ -54,7 +54,7 @@ def from_name(cls, name: str) -> Self: assert name in cls.__members__, f"Unexpected label name: {name}" return cls.__members__[name] - def __eq__(self, other: LabelID | str) -> bool: + def __eq__(self, other: str | LabelID) -> bool: return self.name == other.upper() if isinstance(other, str) else self.name == other.name @@ -72,8 +72,8 @@ class SemanticLabel: original: str | None = field(default=None) attributes: list[str] = field(default_factory=list) - def __eq__(self, other: SemanticLabel) -> bool: - return self.label == other.label + def __eq__(self, other: str | SemanticLabel) -> bool: + return self.label == other if isinstance(other, str) else self.label == other.label # ===================== diff --git a/t4_devkit/common/transform.py b/t4_devkit/dataclass/transform.py similarity index 85% rename from t4_devkit/common/transform.py rename to t4_devkit/dataclass/transform.py index ea5408d..7acf0d3 100644 --- a/t4_devkit/common/transform.py +++ b/t4_devkit/dataclass/transform.py @@ -1,15 +1,49 @@ from __future__ import annotations +from dataclasses import dataclass, field from typing import TYPE_CHECKING, overload import numpy as np from pyquaternion import Quaternion from typing_extensions import Self +from t4_devkit.typing import NDArray, RotationType + if TYPE_CHECKING: - from t4_devkit.typing import ArrayLike, NDArray, RotationType + from t4_devkit.typing import ArrayLike + +__all__ = ["TransformBuffer", "HomogeneousMatrix", "TransformLike"] + + +@dataclass +class TransformBuffer: + buffer: dict[tuple[str, str], HomogeneousMatrix] = field(default_factory=dict, init=False) + + def set_transform(self, matrix: HomogeneousMatrix) -> None: + """Set transform matrix to the buffer. + Also, if its inverse transformation has not been registered, registers it too. + + Args: + matrix (HomogeneousMatrix): Transformation matrix. + """ + src = matrix.src + dst = matrix.dst + if (src, dst) not in self.buffer: + self.buffer[(src, dst)] = matrix + + if (dst, src) not in self.buffer: + self.buffer[(dst, src)] = matrix.inv() + + def lookup_transform(self, src: str, dst: str) -> HomogeneousMatrix | None: + if src == dst: + return HomogeneousMatrix(np.zeros(3), Quaternion(), src=src, dst=dst) + return self.buffer[(src, dst)] if (src, dst) in self.buffer else None + + def do_transform(self, src: str, dst: str, *args: TransformLike) -> TransformLike | None: + return self.buffer[(src, dst)].transform(args) if (src, dst) in self.buffer else None +@dataclass class HomogeneousMatrix: def __init__( self, @@ -214,6 +248,9 @@ def __transform_matrix(self, matrix: HomogeneousMatrix) -> HomogeneousMatrix: return matrix.dot(self) +TransformLike = NDArray | tuple[NDArray, RotationType] | HomogeneousMatrix + + def _extract_position_and_rotation_from_matrix( matrix: NDArray | HomogeneousMatrix, ) -> tuple[NDArray, Quaternion]: @@ -265,3 +302,4 @@ def _generate_homogeneous_matrix( matrix[:3, 3] = position matrix[:3, :3] = rotation.rotation_matrix return matrix + return matrix diff --git a/t4_devkit/filtering/__init__.py b/t4_devkit/filtering/__init__.py new file mode 100644 index 0000000..33118f8 --- /dev/null +++ b/t4_devkit/filtering/__init__.py @@ -0,0 +1,4 @@ +from .compose import BoxFilter +from .parameter import FilterParams + +__all__ = ["BoxFilter", "FilterParams"] diff --git a/t4_devkit/filtering/compose.py b/t4_devkit/filtering/compose.py new file mode 100644 index 0000000..29bd923 --- /dev/null +++ b/t4_devkit/filtering/compose.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from .functional import ( + FilterByDistance, + FilterByLabel, + FilterByNumPoints, + FilterByPosition, + FilterBySpeed, + FilterByUUID, +) +from .parameter import FilterParams + +if TYPE_CHECKING: + from t4_devkit.dataclass import BoxType, TransformBuffer + + from .functional import FilterLike + + +class BoxFilter: + """A class composes multiple filtering functions.""" + + def __init__(self, params: FilterParams, tf_buffer: TransformBuffer) -> None: + """Construct a new object. + + Args: + params (FilterParams): Filtering parameters. + tf_buffer (TransformBuffer): Transformation buffer. + """ + self.filters: list[FilterLike] = [ + FilterByLabel.from_params(params), + FilterByUUID.from_params(params), + FilterByDistance.from_params(params), + FilterByPosition.from_params(params), + FilterBySpeed.from_params(params), + FilterByNumPoints.from_params(params), + ] + + self.tf_buffer = tf_buffer + + def __call__(self, boxes: Sequence[BoxType]) -> list[BoxType]: + output: list[BoxType] = [] + + for box in boxes: + tf_matrix = self.tf_buffer.lookup_transform(box.frame_id, "base_link") + + is_ok = all(func(box, tf_matrix) for func in self.filters) + + if is_ok: + output.append(box) + + return output diff --git a/t4_devkit/filtering/functional.py b/t4_devkit/filtering/functional.py new file mode 100644 index 0000000..0ea5311 --- /dev/null +++ b/t4_devkit/filtering/functional.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Sequence, TypeVar + +import numpy as np +from typing_extensions import Self + +from t4_devkit.dataclass import Box2D, Box3D, HomogeneousMatrix, distance_box +from t4_devkit.filtering.parameter import FilterParams + +if TYPE_CHECKING: + from t4_devkit.dataclass import BoxType, SemanticLabel + + +__all__ = [ + "BaseBoxFilter", + "FilterByLabel", + "FilterByUUID", + "FilterByDistance", + "FilterByPosition", + "BoxFilterLike", +] + + +class BaseBoxFilter(ABC): + """Abstract base class of box filter functions.""" + + @classmethod + @abstractmethod + def from_params(cls, params: FilterParams) -> Self: + """Construct a new object from `FilterParams`. + + Args: + params (FilterParams): Filtering parameters. + + Returns: + A new self object. + """ + pass + + @abstractmethod + def __call__(self, box: BoxType, tf_matrix: HomogeneousMatrix | None = None) -> bool: + """Check whether the input box satisfies requirements. + + Args: + box (BoxType): Box. + tf_matrix (HomogeneousMatrix): Transformation matrix. + + Returns: + Return `True` if the box satisfies requirements. + """ + pass + + +class FilterByLabel(BaseBoxFilter): + """Filter a box by checking if the label of the box is included in specified labels. + + Note that, if `labels` is None all boxes pass through this filter. + """ + + def __init__(self, labels: Sequence[str | SemanticLabel] | None = None) -> None: + """Construct a new object. + + Args: + labels (Sequence[str | SemanticLabel] | None, optional): Sequence of target labels. + If `None`, this filter always returns `True`. + """ + super().__init__() + self.labels = labels + + @classmethod + def from_params(cls, params: FilterParams) -> Self: + return cls(params.labels) + + def __call__(self, box: BoxType, _tf_matrix: HomogeneousMatrix | None = None) -> bool: + if self.labels is None: + return True + + return box.semantic_label in self.labels + + +class FilterByUUID(BaseBoxFilter): + """Filter a box by checking if the uuid of the box is included in specified uuids. + + Note that, if `uuids` is None all boxes pass through this filter. + """ + + def __init__(self, uuids: Sequence[str] | None = None) -> None: + """Construct a new object. + + Args: + uuids (Sequence[str] | None, optional): Sequence of target uuids. + If `None`, this filter always returns `True`. + """ + super().__init__() + self.uuids = uuids + + @classmethod + def from_params(cls, params: FilterParams) -> Self: + return cls(params.uuids) + + def __call__(self, box: BoxType, _tf_matrix: HomogeneousMatrix | None = None) -> bool: + if self.uuids is None: + return True + + return box.uuid in self.uuids + + +class FilterByDistance(BaseBoxFilter): + """Filter a box by checking if the box is within the specified distance. + + Note that, the type box is `Box2D` and its `position` is None, + these boxes pass through this filter. + """ + + def __init__(self, min_distance: float, max_distance: float) -> None: + """Construct a new object. + + Args: + min_distance (float): Minimum distance from the ego [m]. + max_distance (float): Maximum distance from the ego [m]. + """ + super().__init__() + self.min_distance = min_distance + self.max_distance = max_distance + + @classmethod + def from_params(cls, params: FilterParams) -> Self: + return cls(params.min_distance, params.max_distance) + + def __call__(self, box: BoxType, tf_matrix: HomogeneousMatrix) -> bool: + box_distance = distance_box(box, tf_matrix) + + # box_distance is None, only if the box is 2D and its position is None. + if box_distance is None: + return True + else: + return self.min_distance < box_distance and box_distance < self.max_distance + + +class FilterByPosition(BaseBoxFilter): + """Filter a box by checking if the box xy position is within the specified xy position. + + Note that, the type box is `Box2D` and its `position` is None, + these boxes pass through this filter. + """ + + def __init__(self, min_xy: tuple[float, float], max_xy: tuple[float, float]) -> None: + """Construct a new object. + + Args: + min_xy (tuple[float, float]): Minimum xy position [m]. + max_xy (tuple[float, float]): Maximum xy position [m]. + """ + super().__init__() + self.min_xy = min_xy + self.max_xy = max_xy + + @classmethod + def from_params(cls, params: FilterParams) -> Self: + return cls(params.min_xy, params.max_xy) + + def __call__(self, box: BoxType, tf_matrix: HomogeneousMatrix) -> bool: + if isinstance(box, Box2D) and box.position is None: + return True + + if isinstance(box, Box2D): + position = tf_matrix.transform(box.position) + elif isinstance(box, Box3D): + position, _ = tf_matrix.transform(box.position, box.rotation) + else: + raise TypeError(f"Unexpected box type: {type(box)}") + + return np.all((self.min_xy < position[:2]) & (position[:2] < self.max_xy)) + + +class FilterBySpeed(BaseBoxFilter): + """Filter a 3D box by checking if the box speed is within the specified one. + + Note that, the type box is `Box2D`, or `Box3D` and its `velocity` is None, + these boxes pass through this filter. + """ + + def __init__(self, min_speed: float, max_speed: float) -> None: + """Construct a new object. + + Args: + min_speed (float): Minimum speed [m/s]. + max_speed (float): Maximum speed [m/s]. + """ + super().__init__() + self.min_speed = min_speed + self.max_speed = max_speed + + @classmethod + def from_params(cls, params: FilterParams) -> Self: + return cls(params.min_speed, params.max_speed) + + def __call__(self, box: BoxType, _tf_matrix: HomogeneousMatrix | None = None) -> bool: + if isinstance(box, Box2D): + return True + elif isinstance(box, Box3D) and box.velocity is None: + return True + else: + speed = np.linalg.norm(box.velocity) + return self.min_speed < speed and speed < self.max_speed + + +class FilterByNumPoints(BaseBoxFilter): + """Filter a 3D box by checking if the box includes points greater than the specified one. + + Note that, the type box is `Box2D`, or `Box3D` and its `num_points` is None, + these boxes pass through this filter. + """ + + def __init__(self, min_num_points: int = 0) -> None: + """Construct a new object. + + Args: + min_num_points (int, optional): The minimum number of points that a box should include. + """ + super().__init__() + self.min_num_points = min_num_points + + @classmethod + def from_params(cls, params: FilterParams) -> Self: + return cls(params.min_num_points) + + def __call__(self, box: BoxType, _tf_matrix: HomogeneousMatrix | None = None) -> bool: + if isinstance(box, Box2D): + return True + elif isinstance(box, Box3D) and box.num_points is None: + return True + else: + return self.min_num_points <= box.num_points + + +BoxFilterLike = TypeVar("BoxFilterLike", bound=BaseBoxFilter) diff --git a/t4_devkit/filtering/parameter.py b/t4_devkit/filtering/parameter.py new file mode 100644 index 0000000..922aebb --- /dev/null +++ b/t4_devkit/filtering/parameter.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Sequence + +import numpy as np + +if TYPE_CHECKING: + from t4_devkit.dataclass import SemanticLabel + + +@dataclass +class FilterParams: + """A dataclass to represent filtering parameters. + + Attributes: + labels (Sequence[str | SemanticLabel] | None, optional): Sequence of target labels. + uuids (Sequence[str] | None, optional): Sequence of target uuids. + min_distance (float, optional): Minimum distance from the ego [m]. + max_distance (float, optional): Maximum distance from the ego [m]. + min_xy (tuple[float, float], optional): Minimum xy position from the ego [m]. + min_xy (tuple[float, float], optional): Maximum xy position from the ego [m]. + min_speed (float, optional): Minimum speed [m/s]. + max_speed (float, optional): Maximum speed [m/s]. + min_num_points (int): The minimum number of points which the 3D box should include. + """ + + labels: Sequence[str | SemanticLabel] | None = field(default=None) + uuids: Sequence[str] | None = field(default=None) + min_distance: float = field(default=0.0) + max_distance: float = field(default=np.inf) + min_xy: tuple[float, float] = field(default=(-np.inf, -np.inf)) + max_xy: tuple[float, float] = field(default=(np.inf, np.inf)) + min_speed: float = field(default=0.0) + max_speed: float = field(default=np.inf) + min_num_points: int = field(default=0) diff --git a/tests/conftest.py b/tests/conftest.py index 8e039dd..985c597 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,16 @@ import pytest from pyquaternion import Quaternion -from t4_devkit.dataclass import Box2D, Box3D, LabelID, SemanticLabel, Shape, ShapeType +from t4_devkit.dataclass import ( + Box2D, + Box3D, + HomogeneousMatrix, + LabelID, + SemanticLabel, + Shape, + ShapeType, + TransformBuffer, +) @pytest.fixture(scope="module") @@ -118,3 +127,33 @@ def dummy_box2ds() -> list[Box2D]: uuid="pedestrian2d_1", ), ] + + +@pytest.fixture(scope="module") +def dummy_tf_buffer() -> TransformBuffer: + """Return a dummy transformation buffer. + + Returns: + Buffer includes `base_link` to `map` and `base_link` to `camera` transformation. + """ + tf_buffer = TransformBuffer() + + tf_buffer.set_transform( + HomogeneousMatrix( + [1.0, 1.0, 1.0], + Quaternion([0.0, 0.0, 0.0, 1.0]), + src="base_link", + dst="map", + ) + ) + + tf_buffer.set_transform( + HomogeneousMatrix( + [1.0, 1.0, 1.0], + Quaternion([0.0, 0.0, 0.0, 1.0]), + src="base_link", + dst="camera", + ) + ) + + return tf_buffer diff --git a/tests/dataclass/test_box.py b/tests/dataclass/test_box.py index bdcc1c3..e540270 100644 --- a/tests/dataclass/test_box.py +++ b/tests/dataclass/test_box.py @@ -1,8 +1,16 @@ +import math + import numpy as np +from t4_devkit.dataclass import distance_box + def test_box3d(dummy_box3d) -> None: - """Test `Box3D` class.""" + """Test `Box3D` class. + + Args: + dummy_box3d (Box3D): 3D box. + """ # test properties assert np.allclose(dummy_box3d.size, (1.0, 1.0, 1.0)) assert dummy_box3d.area == 1.0 @@ -26,7 +34,11 @@ def test_box3d(dummy_box3d) -> None: def test_box2d(dummy_box2d) -> None: - """Test `Box2D` class.""" + """Test `Box2D` class. + + Args: + dummy_box2d (Box2D): 2D box. + """ # test properties assert dummy_box2d.offset == (100, 100) assert dummy_box2d.size == (50, 50) @@ -34,3 +46,16 @@ def test_box2d(dummy_box2d) -> None: assert dummy_box2d.height == 50 assert dummy_box2d.center == (125, 125) assert dummy_box2d.area == 2500 + + +def test_distance_box(dummy_box3d, dummy_tf_buffer) -> None: + """Test `distance_box` function. + + Args: + dummy_box3d (Box3D): 3D box. + dummy_tf_buffer (TransformBuffer): Transformation buffer. + """ + tf_matrix = dummy_tf_buffer.lookup_transform(dummy_box3d.frame_id, "base_link") + distance = distance_box(dummy_box3d, tf_matrix) + + assert math.isclose(distance, np.linalg.norm([1.0, 1.0, 1.0])) diff --git a/tests/common/test_transform.py b/tests/dataclass/test_transform.py similarity index 97% rename from tests/common/test_transform.py rename to tests/dataclass/test_transform.py index 412733b..6907594 100644 --- a/tests/common/test_transform.py +++ b/tests/dataclass/test_transform.py @@ -2,7 +2,7 @@ import numpy as np -from t4_devkit.common.transform import HomogeneousMatrix +from t4_devkit.dataclass.transform import HomogeneousMatrix def test_homogeneous_matrix_transform(): diff --git a/tests/fitering/test_filter_compose.py b/tests/fitering/test_filter_compose.py new file mode 100644 index 0000000..2542fb2 --- /dev/null +++ b/tests/fitering/test_filter_compose.py @@ -0,0 +1,30 @@ +from t4_devkit.filtering import BoxFilter, FilterParams + + +def test_composite_filter(dummy_box3ds, dummy_box2ds, dummy_tf_buffer) -> None: + """Test `BoxFilter` compositing the box filters. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + dummy_box2ds (list[Box2D]): List of 2D boxes. + dummy_tf_buffer (TransformBuffer): Transformation buffer. + """ + params = FilterParams( + labels=["car"], + uuids=["car3d_1", "car2d_1"], + min_distance=0.0, + max_distance=2.0, + min_xy=(0.0, 0.0), + max_xy=(2.0, 2.0), + min_speed=0.5, + max_speed=2.0, + min_num_points=0, + ) + + box_filter = BoxFilter(params, dummy_tf_buffer) + + answer3d = box_filter(dummy_box3ds) + answer2d = box_filter(dummy_box2ds) + + assert len(answer3d) == 1 + assert len(answer2d) == 1 diff --git a/tests/fitering/test_filter_function.py b/tests/fitering/test_filter_function.py new file mode 100644 index 0000000..522da66 --- /dev/null +++ b/tests/fitering/test_filter_function.py @@ -0,0 +1,118 @@ +from t4_devkit.filtering.functional import ( + FilterByDistance, + FilterByLabel, + FilterByNumPoints, + FilterByPosition, + FilterBySpeed, + FilterByUUID, +) + + +def test_filter_by_label(dummy_box3ds, dummy_box2ds) -> None: + """Test `FilterByLabel` class. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + dummy_box2ds (list[Box2D]): List of 2D boxes. + """ + box_filter = FilterByLabel(labels=["car"]) + + answer3d = [box for box in dummy_box3ds if box_filter(box)] + answer2d = [box for box in dummy_box2ds if box_filter(box)] + + assert len(answer3d) == 1 + assert len(answer2d) == 1 + + +def test_filter_by_uuid(dummy_box3ds, dummy_box2ds) -> None: + """Test `FilterByUUID` class. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + dummy_box2ds (list[Box2D]): List of 2D boxes. + """ + box_filter = FilterByUUID(uuids=["car3d_1", "car2d_1"]) + + answer3d = [box for box in dummy_box3ds if box_filter(box)] + answer2d = [box for box in dummy_box2ds if box_filter(box)] + + assert len(answer3d) == 1 + assert len(answer2d) == 1 + + +def test_filter_by_distance(dummy_box3ds, dummy_box2ds, dummy_tf_buffer) -> None: + """Test `FilterByDistance`. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + dummy_box2ds (list[Box2D]): List of 2D boxes. + dummy_tf_buffer (TransformBuffer): Transformation buffer. + """ + box_filter = FilterByDistance(min_distance=0.0, max_distance=10.0) + + answer3d = [ + box + for box in dummy_box3ds + if box_filter(box, dummy_tf_buffer.lookup_transform(box.frame_id, "base_link")) + ] + + answer2d = [ + box + for box in dummy_box2ds + if box_filter(box, dummy_tf_buffer.lookup_transform(box.frame_id, "base_link")) + ] + + assert len(answer3d) == 3 + assert len(answer2d) == 3 + + +def test_filter_by_position(dummy_box3ds, dummy_box2ds, dummy_tf_buffer) -> None: + """Test `FilterByPosition`. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + dummy_box2ds (list[Box2D]): List of 2D boxes. + dummy_tf_buffer (TransformBuffer): Transformation buffer. + """ + box_filter = FilterByPosition(min_xy=(0.0, 0.0), max_xy=(10.0, 10.0)) + + answer3d = [ + box + for box in dummy_box3ds + if box_filter(box, dummy_tf_buffer.lookup_transform(box.frame_id, "base_link")) + ] + + answer2d = [ + box + for box in dummy_box2ds + if box_filter(box, dummy_tf_buffer.lookup_transform(box.frame_id, "base_link")) + ] + + assert len(answer3d) == 1 + assert len(answer2d) == 3 + + +def test_filter_by_speed(dummy_box3ds) -> None: + """Test `FilterBySpeed`. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + """ + box_filter = FilterBySpeed(min_speed=0.5, max_speed=2.0) + + answer = [box for box in dummy_box3ds if box_filter(box)] + + assert len(answer) == 3 + + +def test_filter_by_num_points(dummy_box3ds) -> None: + """Test `FilterByNumPoints`. + + Args: + dummy_box3ds (list[Box3D]): List of 3D boxes. + """ + box_filter = FilterByNumPoints(min_num_points=0) + + answer = [box for box in dummy_box3ds if box_filter(box)] + + assert len(answer) == 3