Skip to content

Commit

Permalink
refactor: replace dataclasses to attrs
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Nov 12, 2024
1 parent 042d7bf commit 0ea47a2
Show file tree
Hide file tree
Showing 29 changed files with 213 additions and 456 deletions.
27 changes: 27 additions & 0 deletions t4_devkit/common/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from pyquaternion import Quaternion

if TYPE_CHECKING:
from t4_devkit.typing import ArrayLike, NDArray

__all__ = ["as_quaternion"]


def as_quaternion(value: ArrayLike | NDArray) -> Quaternion:
"""Convert input rotation like array to `Quaternion`.
Args:
value (ArrayLike | NDArray): Rotation matrix or quaternion.
Returns:
Quaternion: Converted instance.
"""
return (
Quaternion(matrix=value)
if isinstance(value, np.ndarray) and value.ndim == 2
else Quaternion(value)
)
2 changes: 1 addition & 1 deletion t4_devkit/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from typing import Any

__all__ = ("load_json",)
__all__ = ("load_json", "save_json")


def load_json(filename: str) -> Any:
Expand Down
36 changes: 12 additions & 24 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar

import numpy as np
from pyquaternion import Quaternion
from attrs import define, field
from attrs.converters import optional
from shapely.geometry import Polygon
from typing_extensions import Self

from t4_devkit.common.converter import as_quaternion

from .roi import Roi
from .trajectory import to_trajectories

Expand Down Expand Up @@ -57,7 +59,7 @@ def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None:
return np.linalg.norm(position)


@dataclass(eq=False)
@define(eq=False)
class BaseBox:
"""Abstract base class for box objects."""

Expand All @@ -72,7 +74,7 @@ class BaseBox:
# >>> e.g.) box.as_state() -> BoxState


@dataclass(eq=False)
@define(eq=False)
class Box3D(BaseBox):
"""A class to represent 3D box.
Expand Down Expand Up @@ -109,25 +111,15 @@ class Box3D(BaseBox):
... )
"""

position: TranslationType
rotation: RotationType
position: TranslationType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
shape: Shape
velocity: VelocityType | None = field(default=None)
velocity: VelocityType | None = field(default=None, converter=optional(np.asarray))
num_points: int | None = field(default=None)

# additional attributes: set by `with_**`
future: list[Trajectory] | None = field(default=None, init=False)

def __post_init__(self) -> None:
if not isinstance(self.position, np.ndarray):
self.position = np.array(self.position)

if not isinstance(self.rotation, Quaternion):
self.rotation = Quaternion(self.rotation)

if self.velocity is not None and not isinstance(self.velocity, np.ndarray):
self.velocity = np.array(self.velocity)

def with_future(
self,
waypoints: list[TrajectoryType],
Expand Down Expand Up @@ -195,7 +187,7 @@ def corners(self, box_scale: float = 1.0) -> NDArrayF64:
return np.dot(self.rotation.rotation_matrix, corners).T + self.position


@dataclass(eq=False)
@define(eq=False)
class Box2D(BaseBox):
"""A class to represent 2D box.
Expand All @@ -222,15 +214,11 @@ class Box2D(BaseBox):
>>> box2d = box2d.with_position(position=(1.0, 1.0, 1.0))
"""

roi: Roi | None = field(default=None)
roi: Roi | None = field(default=None, converter=lambda x: None if x is None else Roi(x))

# additional attributes: set by `with_**`
position: TranslationType | None = field(default=None, init=False)

def __post_init__(self) -> None:
if self.roi is not None and not isinstance(self.roi, Roi):
self.roi = Roi(self.roi)

def with_position(self, position: TranslationType) -> Self:
"""Return a self instance setting `position` attribute.
Expand All @@ -240,7 +228,7 @@ def with_position(self, position: TranslationType) -> Self:
Returns:
Self instance after setting `position`.
"""
self.position = np.array(position) if not isinstance(position, np.ndarray) else position
self.position = np.asarray(position)
return self

def __eq__(self, other: Box2D | None) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions t4_devkit/dataclass/label.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from enum import Enum, auto, unique

from attrs import define, field
from typing_extensions import Self

__all__ = ["LabelID", "SemanticLabel", "convert_label"]
Expand Down Expand Up @@ -58,7 +58,7 @@ def __eq__(self, other: str | LabelID) -> bool:
return self.name == other.upper() if isinstance(other, str) else self.name == other.name


@dataclass(frozen=True, eq=False)
@define(frozen=True, eq=False)
class SemanticLabel:
"""A dataclass to represent semantic labels.
Expand All @@ -70,7 +70,7 @@ class SemanticLabel:

label: LabelID
original: str | None = field(default=None)
attributes: list[str] = field(default_factory=list)
attributes: list[str] = field(factory=list)

def __eq__(self, other: str | SemanticLabel) -> bool:
return self.label == other if isinstance(other, str) else self.label == other.label
Expand Down
22 changes: 13 additions & 9 deletions t4_devkit/dataclass/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import struct
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, TypeVar

import numpy as np
from attrs import define, field

if TYPE_CHECKING:
from typing_extensions import Self
Expand All @@ -21,14 +21,18 @@
]


@dataclass
@define
class PointCloud:
"""Abstract base dataclass for pointcloud data."""

points: NDArrayFloat
points: NDArrayFloat = field(converter=np.asarray)

def __post_init__(self) -> None:
assert self.points.shape[0] == self.num_dims()
@points.validator
def check_dims(self, attribute, value) -> None:
if value.shape[0] != self.num_dims():
raise ValueError(
f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}"
)

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -74,7 +78,7 @@ def transform(self, matrix: NDArrayFloat) -> None:
)[:3, :]


@dataclass
@define
class LidarPointCloud(PointCloud):
"""A dataclass to represent lidar pointcloud."""

Expand All @@ -91,7 +95,7 @@ def from_file(cls, filepath: str) -> Self:
return cls(points.T)


@dataclass
@define
class RadarPointCloud(PointCloud):
# class variables
invalid_states: ClassVar[list[int]] = [0]
Expand Down Expand Up @@ -188,9 +192,9 @@ def from_file(
return cls(points)


@dataclass
@define
class SegmentationPointCloud(PointCloud):
labels: NDArrayU8
labels: NDArrayU8 = field(converter=lambda x: np.asarray(x, dtype=np.uint8))

@staticmethod
def num_dims() -> int:
Expand Down
10 changes: 4 additions & 6 deletions t4_devkit/dataclass/roi.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import RoiType

__all__ = ["Roi"]


@dataclass
@define
class Roi:
roi: RoiType
roi: RoiType = field(converter=tuple)

def __post_init__(self) -> None:
assert len(self.roi) == 4, (
"Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}."
)

if not isinstance(self.roi, tuple):
self.roi = tuple(self.roi)

@property
def offset(self) -> tuple[int, int]:
return self.roi[:2]
Expand Down
11 changes: 4 additions & 7 deletions t4_devkit/dataclass/shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto, unique
from typing import TYPE_CHECKING

import numpy as np
from attrs import define, field
from shapely.geometry import Polygon
from typing_extensions import Self

Expand Down Expand Up @@ -35,7 +35,7 @@ def from_name(cls, name: str) -> Self:
return cls.__members__[name]


@dataclass
@define
class Shape:
"""A dataclass to represent the 3D box shape.
Expand All @@ -47,13 +47,10 @@ class Shape:
"""

shape_type: ShapeType
size: SizeType
size: SizeType = field(converter=np.asarray)
footprint: Polygon = field(default=None)

def __post_init__(self) -> None:
if not isinstance(self.size, np.ndarray):
self.size = np.array(self.size)

def __attrs_post_init__(self) -> None:
if self.shape_type == ShapeType.POLYGON and self.footprint is None:
raise ValueError("`footprint` must be specified for `POLYGON`.")

Expand Down
14 changes: 6 additions & 8 deletions t4_devkit/dataclass/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generator

import numpy as np
from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import TrajectoryType, TranslationType

__all__ = ["Trajectory", "to_trajectories"]


@dataclass
@define
class Trajectory:
"""A dataclass to represent trajectory.
Expand Down Expand Up @@ -41,14 +41,12 @@ class Trajectory:
[2. 2. 2.]
"""

waypoints: TrajectoryType
waypoints: TrajectoryType = field(converter=np.asarray)
confidence: float = field(default=1.0)

def __post_init__(self) -> None:
if not isinstance(self.waypoints, np.ndarray):
self.waypoints = np.array(self.waypoints)

assert self.waypoints.shape[1] == 3
def __attrs_post_init__(self) -> None:
if self.waypoints.shape[1] != 3:
raise ValueError("Trajectory dimension must be 3.")

def __len__(self) -> int:
return len(self.waypoints)
Expand Down
Loading

0 comments on commit 0ea47a2

Please sign in to comment.