From 45ab4e01f2fd9d6103492c42f54224f2e21331e6 Mon Sep 17 00:00:00 2001 From: Kyle Vedder Date: Thu, 25 Jul 2024 19:31:31 -0400 Subject: [PATCH 1/2] Added box loader --- .../datasets/argoverse2/__init__.py | 10 ++- .../argoverse2/argoverse_box_annotations.py | 81 +++++++++++++++++++ .../argoverse2/argoverse_scene_flow.py | 2 +- .../datasets/argoverse2/dataset.py | 9 ++- .../datastructures/__init__.py | 2 + .../datastructures/dataclasses.py | 21 +++++ tests/datasets/argoverse2/av2_box_tests.py | 29 +++++++ tests/datasets/argoverse2/av2_small_tests.py | 1 + tests/setup.sh | 12 +-- 9 files changed, 155 insertions(+), 12 deletions(-) create mode 100644 bucketed_scene_flow_eval/datasets/argoverse2/argoverse_box_annotations.py create mode 100644 tests/datasets/argoverse2/av2_box_tests.py diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/__init__.py b/bucketed_scene_flow_eval/datasets/argoverse2/__init__.py index d973aa7..baa964e 100644 --- a/bucketed_scene_flow_eval/datasets/argoverse2/__init__.py +++ b/bucketed_scene_flow_eval/datasets/argoverse2/__init__.py @@ -1,13 +1,15 @@ -from .argoverse_raw_data import ( - ArgoverseRawSequence, - ArgoverseRawSequenceLoader, -) +from .argoverse_raw_data import ArgoverseRawSequence, ArgoverseRawSequenceLoader from .argoverse_scene_flow import ( ArgoverseNoFlowSequence, ArgoverseNoFlowSequenceLoader, ArgoverseSceneFlowSequence, ArgoverseSceneFlowSequenceLoader, ) + +from .argoverse_box_annotations import ( + ArgoverseBoxAnnotationSequence, + ArgoverseBoxAnnotationSequenceLoader, +) from .dataset import Argoverse2CausalSceneFlow, Argoverse2NonCausalSceneFlow __all__ = [ diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_box_annotations.py b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_box_annotations.py new file mode 100644 index 0000000..f382b9b --- /dev/null +++ b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_box_annotations.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from pathlib import Path + +from bucketed_scene_flow_eval.datastructures import ( + SE3, + BoundingBox, + TimeSyncedAVLidarData, + TimeSyncedSceneFlowBoxFrame, + TimeSyncedSceneFlowFrame, +) +from bucketed_scene_flow_eval.utils import load_feather + +from .argoverse_scene_flow import ArgoverseNoFlowSequence, ArgoverseNoFlowSequenceLoader + + +class ArgoverseBoxAnnotationSequence(ArgoverseNoFlowSequence): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.timestamp_to_boxes = self._prep_bbox_annotations() + + def _prep_bbox_annotations(self) -> dict[int, list[BoundingBox]]: + annotations_file = self.dataset_dir / "annotations.feather" + assert annotations_file.exists(), f"Annotations file {annotations_file} does not exist" + annotation_df = load_feather(annotations_file) + # Index(['timestamp_ns', 'track_uuid', 'category', 'length_m', 'width_m', + # 'height_m', 'qw', 'qx', 'qy', 'qz', 'tx_m', 'ty_m', 'tz_m', + # 'num_interior_pts'], + # dtype='object') + + # Convert to dictionary keyed by timestamp_ns int + timestamp_to_annotations: dict[int, list[BoundingBox]] = {} + for _, row in annotation_df.iterrows(): + timestamp_ns = row["timestamp_ns"] + if timestamp_ns not in timestamp_to_annotations: + timestamp_to_annotations[timestamp_ns] = [] + pose = SE3.from_rot_w_x_y_z_translation_x_y_z( + row["qw"], + row["qx"], + row["qy"], + row["qz"], + row["tx_m"], + row["ty_m"], + row["tz_m"], + ) + timestamp_to_annotations[timestamp_ns].append( + BoundingBox( + pose=pose, + length=row["length_m"], + width=row["width_m"], + height=row["height_m"], + track_uuid=row["track_uuid"], + category=row["category"], + ) + ) + return timestamp_to_annotations + + def load( + self, idx: int, relative_to_idx: int, with_flow: bool = False + ) -> tuple[TimeSyncedSceneFlowBoxFrame, TimeSyncedAVLidarData]: + scene_flow_frame, lidar_data = super().load(idx, relative_to_idx, with_flow) + timestamp = self.timestamp_list[idx] + boxes = self.timestamp_to_boxes.get(timestamp, []) + return TimeSyncedSceneFlowBoxFrame(**vars(scene_flow_frame), boxes=boxes), lidar_data + + +class ArgoverseBoxAnnotationSequenceLoader(ArgoverseNoFlowSequenceLoader): + + def _load_sequence_uncached(self, sequence_id: str) -> ArgoverseBoxAnnotationSequence: + assert ( + sequence_id in self.sequence_id_to_raw_data + ), f"sequence_id {sequence_id} does not exist" + return ArgoverseBoxAnnotationSequence( + sequence_id, + self.sequence_id_to_raw_data[sequence_id], + self.sequence_id_to_raw_data[sequence_id], + with_classes=False, + **self.load_sequence_kwargs, + ) + + def cache_folder_name(self) -> str: + return f"av2_box_data_use_gt_flow_{self.use_gt_flow}_raw_data_path_{self.raw_data_path}_No_flow_data_path" diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py index fb1dbe6..97e1695 100644 --- a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py +++ b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py @@ -359,7 +359,7 @@ def _load_flow_feather( raise NotImplementedError("No flow data available for ArgoverseNoFlowSequence") def load( - self, idx: int, relative_to_idx: int, with_flow: bool = True + self, idx: int, relative_to_idx: int, with_flow: bool = False ) -> tuple[TimeSyncedSceneFlowFrame, TimeSyncedAVLidarData]: return super().load(idx, relative_to_idx, with_flow=False) diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py b/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py index e0f03cf..6310561 100644 --- a/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py +++ b/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py @@ -14,6 +14,7 @@ NonCausalSeqLoaderDataset, ) +from .argoverse_box_annotations import ArgoverseBoxAnnotationSequenceLoader from .argoverse_raw_data import DEFAULT_POINT_CLOUD_RANGE, PointCloudRange from .argoverse_scene_flow import ( CATEGORY_MAP, @@ -113,10 +114,16 @@ def __init__( eval_type: str = "bucketed_epe", eval_args=dict(), use_cache=True, + load_boxes: bool = False, load_flow: bool = True, **kwargs, ) -> None: - if load_flow: + if load_boxes: + self.sequence_loader = ArgoverseBoxAnnotationSequenceLoader( + raw_data_path=root_dir, + **kwargs, + ) + elif load_flow: self.sequence_loader = ArgoverseSceneFlowSequenceLoader( raw_data_path=root_dir, use_gt_flow=use_gt_flow, diff --git a/bucketed_scene_flow_eval/datastructures/__init__.py b/bucketed_scene_flow_eval/datastructures/__init__.py index 77f7f5f..dc82e52 100644 --- a/bucketed_scene_flow_eval/datastructures/__init__.py +++ b/bucketed_scene_flow_eval/datastructures/__init__.py @@ -1,5 +1,6 @@ from .camera_projection import CameraModel, CameraProjection from .dataclasses import ( + BoundingBox, EgoLidarFlow, MaskArray, PointCloudFrame, @@ -12,6 +13,7 @@ TimeSyncedAVLidarData, TimeSyncedBaseAuxilaryData, TimeSyncedRawFrame, + TimeSyncedSceneFlowBoxFrame, TimeSyncedSceneFlowFrame, VectorArray, ) diff --git a/bucketed_scene_flow_eval/datastructures/dataclasses.py b/bucketed_scene_flow_eval/datastructures/dataclasses.py index 018c75d..6d00916 100644 --- a/bucketed_scene_flow_eval/datastructures/dataclasses.py +++ b/bucketed_scene_flow_eval/datastructures/dataclasses.py @@ -247,6 +247,16 @@ def __len__(self) -> int: return len(self.lookup) +@dataclass +class BoundingBox: + pose: SE3 + length: float + width: float + height: float + track_uuid: str + category: str + + @dataclass(kw_only=True) class TimeSyncedBaseAuxilaryData: pass @@ -285,3 +295,14 @@ def __post_init__(self): assert len(self.flow.full_flow) == len( self.pc.full_pc ), f"flow and pc must have the same length, got {len(self.flow.full_flow)} and {len(self.pc.full_pc)}" + + +@dataclass(kw_only=True) +class TimeSyncedSceneFlowBoxFrame(TimeSyncedSceneFlowFrame): + boxes: list[BoundingBox] + + def __post_init__(self): + assert isinstance(self.boxes, list), f"boxes must be a list, got {type(self.boxes)}" + assert all( + isinstance(box, BoundingBox) for box in self.boxes + ), f"all boxes must be BoundingBox objects, got {self.boxes}" diff --git a/tests/datasets/argoverse2/av2_box_tests.py b/tests/datasets/argoverse2/av2_box_tests.py new file mode 100644 index 0000000..18ad2a7 --- /dev/null +++ b/tests/datasets/argoverse2/av2_box_tests.py @@ -0,0 +1,29 @@ +from pathlib import Path + +import pytest + +from bucketed_scene_flow_eval.datasets.argoverse2 import ( + ArgoverseBoxAnnotationSequenceLoader, + ArgoverseSceneFlowSequenceLoader, +) + + +@pytest.fixture +def av2_box_sequence_loader() -> ArgoverseBoxAnnotationSequenceLoader: + return ArgoverseBoxAnnotationSequenceLoader( + raw_data_path=Path("/tmp/argoverse2_small/val"), + ) + + +def test_load_box_sequence_length( + av2_box_sequence_loader: ArgoverseBoxAnnotationSequenceLoader, +): + sequence = av2_box_sequence_loader.load_sequence("02678d04-cc9f-3148-9f95-1ba66347dff9") + assert len(sequence) == 157, f"expected 157 frames, got {len(sequence)}" + first_frame, lidar_data = sequence.load(0, 0) + assert len(first_frame.boxes) == 23, f"expected 23 boxes, got {len(first_frame.boxes)}" + + sequence = av2_box_sequence_loader.load_sequence("02a00399-3857-444e-8db3-a8f58489c394") + assert len(sequence) == 159, f"expected 159 frames, got {len(sequence)}" + first_frame, lidar_data = sequence.load(0, 0) + assert len(first_frame.boxes) == 10, f"expected 10 boxes, got {len(first_frame.boxes)}" diff --git a/tests/datasets/argoverse2/av2_small_tests.py b/tests/datasets/argoverse2/av2_small_tests.py index edebf6e..0b1b75d 100644 --- a/tests/datasets/argoverse2/av2_small_tests.py +++ b/tests/datasets/argoverse2/av2_small_tests.py @@ -8,6 +8,7 @@ Argoverse2NonCausalSceneFlow, ) from bucketed_scene_flow_eval.datasets.argoverse2 import ( + ArgoverseBoxAnnotationSequenceLoader, ArgoverseSceneFlowSequenceLoader, ) from bucketed_scene_flow_eval.datastructures import ( diff --git a/tests/setup.sh b/tests/setup.sh index bc08363..d02a522 100755 --- a/tests/setup.sh +++ b/tests/setup.sh @@ -51,9 +51,9 @@ wget -q https://github.com/kylevedder/BucketedSceneFlowEval/files/13924555/waymo unzip -q /tmp/waymo_open_processed_flow_tiny.zip -d /tmp/ -# Prepare /tmp/nuscenes v1.0-mini -rm -rf /tmp/nuscenes -mkdir -p /tmp/nuscenes -echo "Downloading nuscenes v1.0-mini" -wget -q https://www.nuscenes.org/data/v1.0-mini.tgz -O /tmp/nuscenes/nuscenes_v1.0-mini.tgz -tar -xzf /tmp/nuscenes/nuscenes_v1.0-mini.tgz -C /tmp/nuscenes +# # Prepare /tmp/nuscenes v1.0-mini +# rm -rf /tmp/nuscenes +# mkdir -p /tmp/nuscenes +# echo "Downloading nuscenes v1.0-mini" +# wget -q https://www.nuscenes.org/data/v1.0-mini.tgz -O /tmp/nuscenes/nuscenes_v1.0-mini.tgz +# tar -xzf /tmp/nuscenes/nuscenes_v1.0-mini.tgz -C /tmp/nuscenes From 18912a4fc7833632f85363ec591ad30358046e86 Mon Sep 17 00:00:00 2001 From: Kyle Vedder Date: Thu, 25 Jul 2024 20:51:58 -0400 Subject: [PATCH 2/2] Added basic tests for box loading --- .../argoverse2/argoverse_scene_flow.py | 1 + .../datasets/argoverse2/dataset.py | 8 +- flow_lab/flow_lab.py | 416 ++++++++++++++++++ .../o3d_raw_vis_demo.py | 0 tests/integration_tests.py | 30 +- 5 files changed, 451 insertions(+), 4 deletions(-) create mode 100644 flow_lab/flow_lab.py rename o3d_raw_vis_demo.py => flow_lab/o3d_raw_vis_demo.py (100%) diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py index 97e1695..438b95c 100644 --- a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py +++ b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py @@ -264,6 +264,7 @@ def _subset_log(self, log_subset: Optional[list[str]]): self.sequence_id_lst = [ sequence_id for sequence_id in self.sequence_id_lst if sequence_id in log_subset ] + assert len(self.sequence_id_lst) > 0, f"No sequences found in log_subset {log_subset}" def _sanitize_raw_data_path(self, raw_data_path: Union[Path, list[Path]]) -> list[Path]: if isinstance(raw_data_path, str): diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py b/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py index 6310561..0198e8f 100644 --- a/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py +++ b/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py @@ -70,11 +70,17 @@ def __init__( flow_data_path: Optional[Union[Path, list[Path]]] = None, eval_type: str = "bucketed_epe", eval_args=dict(), + load_boxes: bool = False, load_flow: bool = True, use_cache=True, **kwargs, ) -> None: - if load_flow: + if load_boxes: + self.sequence_loader = ArgoverseBoxAnnotationSequenceLoader( + root_dir, + **kwargs, + ) + elif load_flow: self.sequence_loader = ArgoverseSceneFlowSequenceLoader( root_dir, use_gt_flow=use_gt_flow, diff --git a/flow_lab/flow_lab.py b/flow_lab/flow_lab.py new file mode 100644 index 0000000..cfd3eac --- /dev/null +++ b/flow_lab/flow_lab.py @@ -0,0 +1,416 @@ +import argparse +from pathlib import Path + +import numpy as np +import open3d as o3d + +from bucketed_scene_flow_eval.datasets import construct_dataset +from bucketed_scene_flow_eval.datastructures import ( + SE3, + BoundingBox, + TimeSyncedSceneFlowBoxFrame, +) +from bucketed_scene_flow_eval.interfaces import AbstractSequence +from bucketed_scene_flow_eval.utils.glfw_key_ids import * + + +def _update_o3d_mesh_pose(mesh: o3d.geometry.TriangleMesh, start_pose: SE3, target_pose: SE3): + global_translation = target_pose.translation - start_pose.translation + global_rotation = target_pose.rotation_matrix @ np.linalg.inv(start_pose.rotation_matrix) + + mesh.translate(global_translation) + mesh.rotate(global_rotation, center=target_pose.translation) + + +class BoxGeometryWithPose: + def __init__(self, base_box: BoundingBox): + self.base_box = base_box + + # O3D doesn't support rendering boxers as wireframes directly, so we create a box and its associated rendered lineset. + self.o3d_triangle_mesh = o3d.geometry.TriangleMesh.create_box( + width=base_box.length, height=base_box.height, depth=base_box.width + ) + self.o3d_wireframe = o3d.geometry.LineSet.create_from_triangle_mesh(self.o3d_triangle_mesh) + self.imit_pose_of_o3d_geomerty(base_box.pose) + + def imit_pose_of_o3d_geomerty(self, pose: SE3): + o3d_geom_centering_translation = -np.array( + [0.5 * self.base_box.length, 0.5 * self.base_box.height, 0.5 * self.base_box.width] + ) + + center_offset_se3 = SE3.identity().translate(o3d_geom_centering_translation) + o3d_target_pose = pose.compose(center_offset_se3) + + _update_o3d_mesh_pose(self.o3d_wireframe, SE3.identity(), o3d_target_pose) + _update_o3d_mesh_pose(self.o3d_triangle_mesh, SE3.identity(), o3d_target_pose) + + def compute_global_pose( + self, + forward: float = 0, + left: float = 0, + up: float = 0, + pitch: float = 0, + yaw: float = 0, + roll: float = 0, + ) -> SE3: + local_frame_offset_se3 = SE3.from_rot_x_y_z_translation_x_y_z( + roll, pitch, yaw, forward, left, up + ) + + return self.base_box.pose.compose(local_frame_offset_se3) + + def update_from_global(self, global_se3: SE3): + _update_o3d_mesh_pose(self.o3d_wireframe, self.base_box.pose, global_se3) + _update_o3d_mesh_pose(self.o3d_triangle_mesh, self.base_box.pose, global_se3) + self.base_box.pose = global_se3 + + def triangle_mesh_o3d(self) -> o3d.geometry.TriangleMesh: + return self.o3d_triangle_mesh + + def wireframe_o3d(self) -> o3d.geometry.LineSet: + return self.o3d_wireframe + + @property + def pose(self) -> SE3: + return self.base_box.pose + + +def ray_triangle_intersect(ray_origin, ray_direction, v0, v1, v2) -> tuple[bool, np.ndarray | None]: + epsilon = 1e-8 + edge1 = v1 - v0 + edge2 = v2 - v0 + h = np.cross(ray_direction, edge2) + a = np.dot(edge1, h) + if -epsilon < a < epsilon: + return False, None # This ray is parallel to this triangle. + f = 1.0 / a + s = ray_origin - v0 + u = f * np.dot(s, h) + if not (0.0 <= u <= 1.0): + return False, None + q = np.cross(s, edge1) + v = f * np.dot(ray_direction, q) + if not (0.0 <= v <= 1.0): + return False, None + if u + v > 1.0: + return False, None + t = f * np.dot(edge2, q) + if t > epsilon: + intersect_point = ray_origin + ray_direction * t + return True, intersect_point + else: + return ( + False, + None, + ) # This means that there is a line intersection but not a ray intersection. + + +class ViewStateManager: + def __init__(self) -> None: + self.prior_mouse_position: tuple[float, float] | None = None + self.is_view_rotating = False + self.is_translating = False + self.pixel_to_rotate_scale_factor = 1 + self.pixel_to_translate_scale_factor = 1 + self.clickable_geometries: dict[str, BoxGeometryWithPose] = {} + self.selection_axes: o3d.geometry.TriangleMesh | None = None + self.selected_mesh_id: str | None = None + + def add_clickable_geometry(self, id: str, box_geometry: BoxGeometryWithPose): + self.clickable_geometries[id] = box_geometry + + def _update_selection( + self, + vis, + forward: float = 0, + left: float = 0, + up: float = 0, + pitch: float = 0, + yaw: float = 0, + roll: float = 0, + ): + assert self.selected_mesh_id is not None + assert self.selection_axes is not None + selected_mesh = self.clickable_geometries[self.selected_mesh_id] + global_target_se3 = selected_mesh.compute_global_pose( + forward=forward, left=left, up=up, pitch=pitch, yaw=yaw, roll=roll + ) + # for g in global_target_se3.to_o3d(): + # vis.add_geometry(g, reset_bounding_box=False) + + _update_o3d_mesh_pose(self.selection_axes, selected_mesh.pose, global_target_se3) + selected_mesh.update_from_global(global_target_se3) + + vis.update_geometry(selected_mesh.wireframe_o3d()) + vis.update_geometry(self.selection_axes) + + def forward_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, forward=0.1) + + def backward_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, forward=-0.1) + + def left_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, left=0.1) + + def right_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, left=-0.1) + + def up_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, up=0.1) + + def down_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, up=-0.1) + + def yaw_clockwise_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, yaw=0.1) + + def yaw_counterclockwise_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, yaw=-0.1) + + def pitch_up_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, pitch=0.1) + + def pitch_down_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, pitch=-0.1) + + def roll_clockwise_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, roll=0.1) + + def roll_counterclockwise_press(self, vis): + if self.selected_mesh_id is None: + return + self._update_selection(vis, roll=-0.1) + + def on_mouse_move(self, vis, x, y): + if self.prior_mouse_position is not None: + dx = x - self.prior_mouse_position[0] + dy = y - self.prior_mouse_position[1] + view_control = vis.get_view_control() + if self.is_view_rotating: + view_control.rotate( + dx * self.pixel_to_rotate_scale_factor, dy * self.pixel_to_rotate_scale_factor + ) + elif self.is_translating: + view_control.translate( + dx * self.pixel_to_translate_scale_factor, + dy * self.pixel_to_translate_scale_factor, + ) + + self.prior_mouse_position = (x, y) + + def on_mouse_scroll(self, vis, x, y): + view_control = vis.get_view_control() + view_control.scale(y) + + def on_mouse_button(self, vis, button, action, mods): + buttons = ["left", "right", "middle"] + actions = ["up", "down"] + mods_name = ["shift", "ctrl", "alt", "cmd"] + + button = buttons[button] + action = actions[action] + mods = [mods_name[i] for i in range(4) if mods & (1 << i)] + + if button == "left" and action == "down": + self.is_view_rotating = True + elif button == "left" and action == "up": + self.is_view_rotating = False + elif button == "middle" and action == "down": + self.is_translating = True + elif button == "middle" and action == "up": + self.is_translating = False + elif button == "right" and action == "down": + self.pick_mesh(vis, self.prior_mouse_position[0], self.prior_mouse_position[1]) + + print(f"on_mouse_button: {button}, {action}, {mods}") + if button == "right" and action == "down": + self.pick_mesh(vis, self.prior_mouse_position[0], self.prior_mouse_position[1]) + + def select_mesh(self, vis, mesh_id: str): + self.selected_mesh_id = mesh_id + if self.selection_axes is not None: + vis.remove_geometry(self.selection_axes, reset_bounding_box=False) + self.selection_axes = o3d.geometry.TriangleMesh.create_coordinate_frame(size=2) + # o3d.geometry.TriangleMesh.create_sphere(radius=1) + + selected_box_with_pose = self.clickable_geometries[mesh_id] + center = selected_box_with_pose.pose.translation + rotation_matrix = selected_box_with_pose.pose.rotation_matrix + # Use the oriented bounding box center as the origin of the axes + self.selection_axes.translate(center, relative=False) + # Use the oriented bounding box rotation as the rotation of the axes + self.selection_axes.rotate(rotation_matrix) + vis.add_geometry(self.selection_axes, reset_bounding_box=False) + + def deselect_mesh(self, vis): + self.selected_mesh_id = None + + if self.selection_axes is not None: + vis.remove_geometry(self.selection_axes, reset_bounding_box=False) + self.selection_axes = None + + def pick_mesh(self, vis, x, y, visualize_click: bool = False): + view_control = vis.get_view_control() + camera_params = view_control.convert_to_pinhole_camera_parameters() + intrinsic = camera_params.intrinsic.intrinsic_matrix + extrinsic = camera_params.extrinsic + + # Create a ray in camera space + ray_camera = np.array( + [ + (x - intrinsic[0, 2]) / intrinsic[0, 0], + (y - intrinsic[1, 2]) / intrinsic[1, 1], + 1.0, + ] + ) + + # Normalize the ray direction + ray_camera = ray_camera / np.linalg.norm(ray_camera) + + # Convert the ray to world space + rotation = extrinsic[:3, :3] + translation = extrinsic[:3, 3] + + ray_world = np.dot(rotation.T, ray_camera) + ray_dir = ray_world / np.linalg.norm(ray_world) + + camera_pos = -np.dot(rotation.T, translation) + + if visualize_click: + # Add sphere at camera position + sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.1) + sphere.translate(camera_pos) + vis.add_geometry(sphere, reset_bounding_box=False) + + # Draw the ray in world space + ray_end = camera_pos + ray_dir * 100 # Extend the ray 100 units + ray_line = o3d.geometry.LineSet() + ray_line.points = o3d.utility.Vector3dVector([camera_pos, ray_end]) + ray_line.lines = o3d.utility.Vector2iVector([[0, 1]]) + ray_line.colors = o3d.utility.Vector3dVector([[1, 0, 0]]) + vis.add_geometry(ray_line, reset_bounding_box=False) + + closest_mesh_lookup: dict[str, float] = {} + for id, box_with_pose in self.clickable_geometries.items(): + mesh = box_with_pose.triangle_mesh_o3d() + vertices = np.asarray(mesh.vertices) + triangles = np.asarray(mesh.triangles) + for tri in triangles: + v0, v1, v2 = vertices[tri] + hit, intersect_point = ray_triangle_intersect(camera_pos, ray_dir, v0, v1, v2) + if hit: + intersection_distance = np.linalg.norm(intersect_point - camera_pos) + closest_mesh_lookup[id] = min( + intersection_distance, closest_mesh_lookup.get(id, np.inf) + ) + + if len(closest_mesh_lookup) == 0: + self.deselect_mesh(vis) + return + + closest_mesh_id = min(closest_mesh_lookup, key=closest_mesh_lookup.get) + print(f"Selected mesh: {closest_mesh_id}") + self.selected_mesh_id = closest_mesh_id + self.select_mesh(vis, closest_mesh_id) + + +def load_box_frames() -> list[TimeSyncedSceneFlowBoxFrame]: + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_name", type=str, default="Argoverse2NonCausalSceneFlow") + parser.add_argument("--root_dir", type=Path, required=True) + parser.add_argument("--sequence_length", type=int, required=True) + parser.add_argument("--sequence_id", type=str, required=True) + args = parser.parse_args() + + sequence_length = args.sequence_length + log_subset = [args.sequence_id] + + dataset = construct_dataset( + name=args.dataset_name, + args=dict( + root_dir=args.root_dir, + subsequence_length=sequence_length, + with_ground=False, + range_crop_type="ego", + load_boxes=True, + log_subset=log_subset, + ), + ) + assert len(dataset) == 1, f"Expected 1 sequence, got {len(dataset)}" + + return dataset[0] + + +def main(): + + frames = load_box_frames() + + frame = frames[0] + + vis = o3d.visualization.VisualizerWithKeyCallback() + + state_manager = ViewStateManager() + + for idx, box in enumerate(frame.boxes): + state_manager.add_clickable_geometry(f"box{idx:06d}", BoxGeometryWithPose(box)) + + vis.register_mouse_move_callback(state_manager.on_mouse_move) + vis.register_mouse_scroll_callback(state_manager.on_mouse_scroll) + vis.register_mouse_button_callback(state_manager.on_mouse_button) + + # fmt: off + vis.register_key_callback(ord("W"), state_manager.forward_press) + vis.register_key_callback(ord("S"), state_manager.backward_press) + vis.register_key_callback(ord("A"), state_manager.left_press) + vis.register_key_callback(ord("D"), state_manager.right_press) + vis.register_key_callback(ord("Z"), state_manager.down_press) + vis.register_key_callback(ord("X"), state_manager.up_press) + vis.register_key_callback(ord("Q"), state_manager.yaw_clockwise_press) + vis.register_key_callback(ord("E"), state_manager.yaw_counterclockwise_press) + # Use arrow keys for pitch and roll + vis.register_key_callback(GLFW_KEY_UP, state_manager.pitch_up_press) + vis.register_key_callback(GLFW_KEY_DOWN, state_manager.pitch_down_press) + vis.register_key_callback(GLFW_KEY_RIGHT, state_manager.roll_clockwise_press) + vis.register_key_callback(GLFW_KEY_LEFT, state_manager.roll_counterclockwise_press) + # fmt: on + + vis.create_window() + vis.add_geometry(frame.pc.ego_pc.to_o3d().paint_uniform_color([0.5, 0.5, 0.5])) + for box in state_manager.clickable_geometries.values(): + vis.add_geometry(box.wireframe_o3d()) + # Add a coordinate frame at the origin + vis.add_geometry(o3d.geometry.TriangleMesh.create_coordinate_frame(size=5)) + render_option = vis.get_render_option() + # render_option.mesh_show_wireframe = True + # render_option.light_on = False + # render_option.mesh_shade_option = o3d.visualization.MeshShadeOption.Default + vis.run() + + +if __name__ == "__main__": + + print("Customized visualization with mouse action.") + main() diff --git a/o3d_raw_vis_demo.py b/flow_lab/o3d_raw_vis_demo.py similarity index 100% rename from o3d_raw_vis_demo.py rename to flow_lab/o3d_raw_vis_demo.py diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 9870c9d..67e2b1e 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -95,22 +95,42 @@ def argo_dataset_test_no_flow_with_ground(): ) +@pytest.fixture +def argo_box_dataset(): + return construct_dataset( + "argoverse2noncausalsceneflow", + dict( + root_dir="/tmp/argoverse2_small/val", + subsequence_length=150, + load_boxes=True, + range_crop_type="ego", + log_subset=["02678d04-cc9f-3148-9f95-1ba66347dff9"], + ), + ) + + def _validate_dataloader( dataloader: AbstractDataset, full_pc_size: int, masked_pc_size: int, expected_len: int = 1, + expected_num_frames: int = 2, ): assert len(dataloader) == expected_len, f"Expected {expected_len} scene, got {len(dataloader)}" + dataloader_entries = dataloader[0] + assert ( + len(dataloader_entries) == expected_num_frames + ), f"Expected list of length {expected_num_frames}, got {len(dataloader_entries)}" # Failure of the following line indicates that the __getitem__ method is broken. - _, _ = dataloader[0] num_iteration_entries = 0 for entry in dataloader: assert isinstance(entry, list), f"Expected list, got {type(entry)}" - assert len(entry) == 2, f"Expected list of length 2, got {len(entry)}" - item_t1, _ = entry + assert ( + len(entry) == expected_num_frames + ), f"Expected list of length {expected_num_frames}, got {len(entry)}" + item_t1 = entry[0] assert ( full_pc_size == item_t1.pc.full_pc.shape[0] @@ -128,6 +148,10 @@ def _validate_dataloader( ), f"Expected {expected_len} iteration, got {num_iteration_entries}" +def test_argo_box_dataset(argo_box_dataset): + _validate_dataloader(argo_box_dataset, 95381, 69600, 1, 150) + + def test_waymo_dataset(waymo_dataset_gt): _validate_dataloader(waymo_dataset_gt, 124364, 124364)