Skip to content

Commit

Permalink
Added multi-frame rendering and frame navigation
Browse files Browse the repository at this point in the history
  • Loading branch information
lisiyi777 committed Jul 30, 2024
1 parent a350a10 commit d8280d5
Showing 1 changed file with 141 additions and 27 deletions.
168 changes: 141 additions & 27 deletions flow_lab/flow_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,27 @@ def _update_o3d_mesh_pose(mesh: o3d.geometry.TriangleMesh, start_pose: SE3, targ


class BoxGeometryWithPose:
def __init__(self, base_box: BoundingBox):
def __init__(self, base_box: BoundingBox, color=[0.1, 0.1, 0.1]):
self.base_box = base_box
self.color = color

# O3D doesn't support rendering boxers as wireframes directly, so we create a box and its associated rendered lineset.
self.o3d_triangle_mesh = o3d.geometry.TriangleMesh.create_box(
width=base_box.length, height=base_box.height, depth=base_box.width
)
self.o3d_wireframe = o3d.geometry.LineSet.create_from_triangle_mesh(self.o3d_triangle_mesh)
self.imit_pose_of_o3d_geomerty(base_box.pose)
self.set_color(self.color)

def set_color(self, color):
"""
Sets the color of the geomerty
Args:
color: A list or array of three floats representing the RGB color.
"""
num_lines = len(self.o3d_wireframe.lines)
self.o3d_wireframe.colors = o3d.utility.Vector3dVector([color] * num_lines)

def imit_pose_of_o3d_geomerty(self, pose: SE3):
o3d_geom_centering_translation = -np.array(
Expand Down Expand Up @@ -106,7 +118,7 @@ def ray_triangle_intersect(ray_origin, ray_direction, v0, v1, v2) -> tuple[bool,


class ViewStateManager:
def __init__(self) -> None:
def __init__(self, frames, cropping_range, rolling_window_size) -> None:
self.prior_mouse_position: tuple[float, float] | None = None
self.is_view_rotating = False
self.is_translating = False
Expand All @@ -115,6 +127,12 @@ def __init__(self) -> None:
self.clickable_geometries: dict[str, BoxGeometryWithPose] = {}
self.selection_axes: o3d.geometry.TriangleMesh | None = None
self.selected_mesh_id: str | None = None
self.current_frame_index = 0
self.tuning_scale = 0.1
self.frames = frames
self.cropping_range = cropping_range
self.rolling_window_size = rolling_window_size
self.trajectory_geometries = [] # Track geometries for the selected mesh trajectory

def add_clickable_geometry(self, id: str, box_geometry: BoxGeometryWithPose):
self.clickable_geometries[id] = box_geometry
Expand Down Expand Up @@ -147,62 +165,82 @@ def _update_selection(
def forward_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, forward=0.1)
self._update_selection(vis, forward=self.tuning_scale)

def backward_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, forward=-0.1)
self._update_selection(vis, forward=-self.tuning_scale)

def left_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, left=0.1)
self._update_selection(vis, left=self.tuning_scale)

def right_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, left=-0.1)
self._update_selection(vis, left=-self.tuning_scale)

def up_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, up=0.1)
self._update_selection(vis, up=self.tuning_scale)

def down_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, up=-0.1)
self._update_selection(vis, up=-self.tuning_scale)

def yaw_clockwise_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, yaw=0.1)
self._update_selection(vis, yaw=self.tuning_scale)

def yaw_counterclockwise_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, yaw=-0.1)
self._update_selection(vis, yaw=-self.tuning_scale)

def pitch_up_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, pitch=0.1)
self._update_selection(vis, pitch=self.tuning_scale)

def pitch_down_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, pitch=-0.1)
self._update_selection(vis, pitch=-self.tuning_scale)

def roll_clockwise_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, roll=0.1)
self._update_selection(vis, roll=self.tuning_scale)

def roll_counterclockwise_press(self, vis):
if self.selected_mesh_id is None:
return
self._update_selection(vis, roll=-0.1)
self._update_selection(vis, roll=-self.tuning_scale)

def forward_frame_press(self, vis):
self.current_frame_index = self.current_frame_index + 1
self.current_frame_index = min(len(self.frames) - 1, self.current_frame_index)
self.render_pc_and_boxes(vis)
print(f"At frame: {self.current_frame_index}")

def backward_frame_press(self, vis):
self.current_frame_index = self.current_frame_index - 1
self.current_frame_index = max(0, self.current_frame_index)
self.render_pc_and_boxes(vis)
print(f"You are at frame: {self.current_frame_index}")

def shift_actions(self, vis, action, mods):
actions = ["up", "down"]
action = actions[action]
if action == "down":
self.tuning_scale = 0.02
elif action == "up":
self.tuning_scale = 0.1

def on_mouse_move(self, vis, x, y):
if self.prior_mouse_position is not None:
Expand Down Expand Up @@ -264,13 +302,15 @@ def select_mesh(self, vis, mesh_id: str):
# Use the oriented bounding box rotation as the rotation of the axes
self.selection_axes.rotate(rotation_matrix)
vis.add_geometry(self.selection_axes, reset_bounding_box=False)
self.render_selected_mesh_trajectory(vis)

def deselect_mesh(self, vis):
self.selected_mesh_id = None

if self.selection_axes is not None:
vis.remove_geometry(self.selection_axes, reset_bounding_box=False)
self.selection_axes = None
self.clear_trajectory_geometries(vis)

def pick_mesh(self, vis, x, y, visualize_click: bool = False):
view_control = vis.get_view_control()
Expand Down Expand Up @@ -336,6 +376,78 @@ def pick_mesh(self, vis, x, y, visualize_click: bool = False):
self.selected_mesh_id = closest_mesh_id
self.select_mesh(vis, closest_mesh_id)

def rolling_window_range(self):
"""
Calculate the range of frames to display
"""
half_window_size = self.rolling_window_size // 2
adjustment = 0 if self.rolling_window_size % 2 else 1
start_index = max(0, self.current_frame_index - half_window_size - adjustment)
end_index = min(len(self.frames), self.current_frame_index + half_window_size + 1)

return start_index, end_index

def clear_trajectory_geometries(self, vis):
"""
Removes the trajectory geometries from the visualizer
"""
for geometry in self.trajectory_geometries:
vis.remove_geometry(geometry, reset_bounding_box=False)
self.trajectory_geometries.clear()

def render_selected_mesh_trajectory(self, vis):
"""
Adds the trajectory geometries for the selected mesh
"""
start_index, end_index = self.rolling_window_range()
self.clear_trajectory_geometries(vis)

if self.selected_mesh_id is not None:
selected_box = self.clickable_geometries[self.selected_mesh_id].base_box
selected_box_uuid = selected_box.track_uuid

for i in range(start_index, end_index):
if i == self.current_frame_index:
continue

# Green color for other frames
frame = self.frames[i]
for box in frame.boxes:
if box.track_uuid == selected_box_uuid:
box_geom = BoxGeometryWithPose(box, color=[0, 0.5, 0])
wireframe = box_geom.wireframe_o3d()
vis.add_geometry(wireframe, reset_bounding_box=False)
self.trajectory_geometries.append(wireframe)

def render_pc_and_boxes(self, vis):
"""
Renders the point clouds and bounding boxes for the given frames
"""
current_frame_index = self.current_frame_index
start_index, end_index = self.rolling_window_range()

vis.clear_geometries()
# Loop over the frames and display pointclouds
for i in range(start_index, end_index):
frame = self.frames[i]
ego_translation = frame.pc.global_pose.translation
pc_color = [1, 0, 0] if i == current_frame_index else [0.75, 0.75, 0.75]
vis.add_geometry(frame.pc.ego_pc.to_o3d().paint_uniform_color(pc_color))

# Render bounding boxes for the current frame only
frame = self.frames[current_frame_index]
for idx, box in enumerate(frame.boxes):
box_translation = box.pose.transform_matrix[:3, 3]
if np.linalg.norm(ego_translation - box_translation) < self.cropping_range:
self.add_clickable_geometry(f"box{idx:06d}", BoxGeometryWithPose(box))
vis.add_geometry(
self.clickable_geometries[f"box{idx:06d}"].wireframe_o3d(),
# reset_bounding_box=False,
)

# Add a coordinate frame at the origin
vis.add_geometry(o3d.geometry.TriangleMesh.create_coordinate_frame(size=5))


def load_box_frames() -> list[TimeSyncedSceneFlowBoxFrame]:
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -364,24 +476,24 @@ def load_box_frames() -> list[TimeSyncedSceneFlowBoxFrame]:
return dataset[0]


def main():
def main(cropping_range=10000.0, rolling_window_size=5):
"""
Args:
rolling_window_size: The number of frames to display around the current frame
cropping_range: The maximum distance from the ego vehicle within which to display bounding boxes
"""

frames = load_box_frames()

frame = frames[0]

vis = o3d.visualization.VisualizerWithKeyCallback()

state_manager = ViewStateManager()

for idx, box in enumerate(frame.boxes):
state_manager.add_clickable_geometry(f"box{idx:06d}", BoxGeometryWithPose(box))
state_manager = ViewStateManager(frames, cropping_range, rolling_window_size)

vis.register_mouse_move_callback(state_manager.on_mouse_move)
vis.register_mouse_scroll_callback(state_manager.on_mouse_scroll)
vis.register_mouse_button_callback(state_manager.on_mouse_button)

# fmt: off
# Use WASD keys for translation, Q and E keys for yaw
vis.register_key_callback(ord("W"), state_manager.forward_press)
vis.register_key_callback(ord("S"), state_manager.backward_press)
vis.register_key_callback(ord("A"), state_manager.left_press)
Expand All @@ -395,14 +507,16 @@ def main():
vis.register_key_callback(GLFW_KEY_DOWN, state_manager.pitch_down_press)
vis.register_key_callback(GLFW_KEY_RIGHT, state_manager.roll_clockwise_press)
vis.register_key_callback(GLFW_KEY_LEFT, state_manager.roll_counterclockwise_press)
# Use , and . keys for going forward and backward through the frames
vis.register_key_callback(ord(","), lambda vis: state_manager.backward_frame_press(vis))
vis.register_key_callback(ord("."), lambda vis: state_manager.forward_frame_press(vis))
# Press Shift to fine tune
vis.register_key_action_callback(GLFW_KEY_LEFT_SHIFT , state_manager.shift_actions)
# fmt: on

vis.create_window()
vis.add_geometry(frame.pc.ego_pc.to_o3d().paint_uniform_color([0.5, 0.5, 0.5]))
for box in state_manager.clickable_geometries.values():
vis.add_geometry(box.wireframe_o3d())
# Add a coordinate frame at the origin
vis.add_geometry(o3d.geometry.TriangleMesh.create_coordinate_frame(size=5))
state_manager.render_pc_and_boxes(vis)

render_option = vis.get_render_option()
# render_option.mesh_show_wireframe = True
# render_option.light_on = False
Expand Down

0 comments on commit d8280d5

Please sign in to comment.