diff --git a/flow_lab/flow_lab.py b/flow_lab/flow_lab.py index 70821f4..39123fa 100644 --- a/flow_lab/flow_lab.py +++ b/flow_lab/flow_lab.py @@ -16,31 +16,6 @@ from bucketed_scene_flow_eval.utils.glfw_key_ids import * -def setup_save_path(root_dir: Path, sequence_id: str, save_dir, preprocess) -> Path: - """ - Sets up the path to the directory where processed data should be saved. - """ - if save_dir: - # Check if the provided save_dir exists - if not save_dir.exists(): - raise ValueError(f"The provided save_dir '{save_dir}' does not exist.") - sequence_save_dir = save_dir / sequence_id - sequence_save_dir.mkdir(exist_ok=True) - else: - if preprocess: - # Get the parent directory of root_dir and create a new directory called root_dir_processed - parent_dir = root_dir.parent - processed_dir = parent_dir / f"{root_dir.name}_processed" - processed_dir.mkdir(exist_ok=True) - sequence_save_dir = processed_dir / sequence_id - sequence_save_dir.mkdir(exist_ok=True) - else: - sequence_save_dir = root_dir / sequence_id - sequence_save_dir.mkdir(exist_ok=True) - - return sequence_save_dir - - def parse_arguments(): # Define the default path for the lookup table inside the flow_lab folder flow_lab_dir = Path(__file__).resolve().parent @@ -69,12 +44,12 @@ def parse_arguments(): default=default_lookup_table_path, help="Path to JSON lookup table for sequence lengths.", ) - parser.add_argument( - "--save_dir", - type=Path, - required=False, - help="Directory where processed data will be saved. ", - ) + # parser.add_argument( + # "--save_dir", + # type=Path, + # required=False, + # help="Directory where processed data will be saved. ", + # ) parser.add_argument( "--preprocess", action="store_true", @@ -88,7 +63,6 @@ def parse_arguments(): ) args = parser.parse_args() - args.save_dir = setup_save_path(args.root_dir, args.sequence_id, args.save_dir, args.preprocess) return args @@ -174,6 +148,7 @@ def setup_visualizer(state_manager, annotation_saver, frames): vis.register_key_action_callback( GLFW_KEY_SPACE, lambda vis, action, mods: state_manager.toggle_box(vis, action, mods) ) + vis.register_key_callback(ord("V"), state_manager.toggle_propagate_with_velocity) vis.create_window() # render_option = vis.get_render_option() @@ -189,6 +164,28 @@ def load_sequence_length(sequence_id: str, lookup_table: Path) -> int: return data.get(sequence_id, 0) # Default to 0 if the sequence_id is not found +def setup_save_path(root_dir: Path, sequence_id: str) -> Path: + """ + Set up the path to the folder where processed annotations should be saved. + """ + parent_dir = root_dir.parent + processed_dir = parent_dir / f"{root_dir.name}_processed" + sequence_save_dir = processed_dir / "val" / sequence_id + + return sequence_save_dir + + +def setup_load_path(root_dir: Path, sequence_id: str) -> Path: + """ + Set up the path to load data when preprocess is required. + """ + parent_dir = root_dir.parent + processed_dir = parent_dir / f"{root_dir.name}_processed" + sequence_save_dir = processed_dir / "val" + + return sequence_save_dir + + def main(): args = parse_arguments() @@ -197,22 +194,27 @@ def main(): if sequence_length == 0: raise ValueError(f"Sequence ID {args.sequence_id} not found in lookup table.") - # load frames - frames = load_box_frames(args.root_dir, args.dataset_name, sequence_length, args.sequence_id) - # this is used to save edited data - annotation_saver = AnnotationSaver(args.save_dir) + output_path = setup_save_path(args.root_dir, args.sequence_id) + annotation_saver = AnnotationSaver(output_path) - if args.preprocess: + # Check if preprocessing is required and load frames + annotation_file_path = output_path / "annotations.feather" + if args.preprocess or not annotation_file_path.exists(): print("Preprocessing data...") + input_path = args.root_dir / "val" + frames = load_box_frames(input_path, args.dataset_name, sequence_length, args.sequence_id) frames = preprocess_box_frames(frames) annotation_saver.save(frames) else: - # declare the view state manager and display the window - state_manager = ViewStateManager(frames, annotation_saver, args.rolling_window_size) - vis = setup_visualizer(state_manager, annotation_saver, frames) - state_manager.render_pc_and_boxes(vis, reset_bounding_box=True) - vis.run() + input_path = setup_load_path(args.root_dir, args.sequence_id) + frames = load_box_frames(input_path, args.dataset_name, sequence_length, args.sequence_id) + + # declare the view state manager and display the window + state_manager = ViewStateManager(frames, annotation_saver, args.rolling_window_size) + vis = setup_visualizer(state_manager, annotation_saver, frames) + state_manager.render_pc_and_boxes(vis, reset_bounding_box=True) + vis.run() if __name__ == "__main__": diff --git a/flow_lab/vis_classes.py b/flow_lab/vis_classes.py index d5234fe..1e53aa7 100644 --- a/flow_lab/vis_classes.py +++ b/flow_lab/vis_classes.py @@ -135,8 +135,12 @@ def __init__( self.frames = frames self.rolling_window_size = rolling_window_size self.trajectory_geometries: list[RenderableBox] = [] + # the two blow is used for zoom self.is_zoomed = False self.original_view = None + # Used for velocity + self.propagate_with_velocity = False + self.velocities = {} # Used for toggle box self.current_box_index = -1 @@ -243,13 +247,16 @@ def roll_counterclockwise_press(self, vis): print("roll_counterclockwise_press") def forward_frame_press(self, vis): - self.annotation_saver.save(self.frames) - self.current_frame_index = self.current_frame_index + 1 - self.current_frame_index = min(len(self.frames) - 1, self.current_frame_index) - self.render_pc_and_boxes(vis) + # self.annotation_saver.save(self.frames) + if self.current_frame_index < len(self.frames) - 1: + self.current_frame_index += 1 + self.render_pc_and_boxes(vis) + if self.propagate_with_velocity: + self.apply_velocity() + self.render_pc_and_boxes(vis) def backward_frame_press(self, vis): - self.annotation_saver.save(self.frames) + # self.annotation_saver.save(self.frames) self.current_frame_index = self.current_frame_index - 1 self.current_frame_index = max(0, self.current_frame_index) self.render_pc_and_boxes(vis) @@ -301,7 +308,7 @@ def on_mouse_scroll(self, vis, x, y): def on_mouse_button(self, vis, button, action, mods): buttons = ["left", "right", "middle"] - actions = ["up", "down"] + actions = ["up", "down", "drag"] mods_name = ["shift", "ctrl", "alt", "cmd"] button = buttons[button] @@ -338,7 +345,7 @@ def select_mesh(self, vis, mesh_id: str): # Use the oriented bounding box rotation as the rotation of the axes self.selection_axes.rotate(rotation_matrix) vis.add_geometry(self.selection_axes, reset_bounding_box=False) - self.render_selected_mesh_trajectory(vis) + # self.render_selected_mesh_trajectory(vis) def deselect_mesh(self, vis): self.selected_mesh_id = None @@ -463,6 +470,7 @@ def render_pc_and_boxes(self, vis, reset_bounding_box: bool = False): start_index, end_index = self.rolling_window_range() vis.clear_geometries() + self.clickable_geometries = {} # Loop over the frames and display pointclouds for i in range(start_index, end_index): frame = self.frames[i] @@ -510,6 +518,9 @@ def get_annotations(self) -> list[dict]: return annotations def zoom_press(self, vis, action, mods): + """ + Callback function for zoom to box. + """ actions = ["release", "press", "repeat"] mods_name = ["shift", "ctrl", "alt", "cmd"] action = actions[action] @@ -539,6 +550,9 @@ def zoom_to_box(self, vis): ctr.set_zoom(0.1) def toggle_box(self, vis, action, mods): + """ + Toggle the selected box. + """ actions = ["release", "press", "repeat"] mods_name = ["shift", "ctrl", "alt", "cmd"] action = actions[action] @@ -569,3 +583,78 @@ def toggle_box(self, vis, action, mods): self.zoom_to_box(vis) # print(f"Toggled to box {new_mesh_id}") + + def toggle_propagate_with_velocity(self, vis): + """ + Toggle the propagate_with_velocity feature on and off. + """ + self.propagate_with_velocity = not self.propagate_with_velocity + print(f"Propagate with velocity: {'On' if self.propagate_with_velocity else 'Off'}") + if self.propagate_with_velocity: + self.compute_velocities() + + def compute_velocities(self): + """ + Compute velocities for all the meshes. + """ + if self.current_frame_index < 2: + return # Need at least two frames to calculate velocity + self.velocities.clear() + + current_frame = self.frames[self.current_frame_index] + prev_frame = self.frames[self.current_frame_index - 1] + prev_prev_frame = self.frames[self.current_frame_index - 2] + + for box, pose in current_frame.boxes.valid_boxes(): + uuid = box.track_uuid + prev_pose = self.find_pose_in_frame(prev_frame, uuid) + prev_prev_pose = self.find_pose_in_frame(prev_prev_frame, uuid) + + if prev_pose and prev_prev_pose: + velocity = self.calculate_velocity(prev_prev_pose, prev_pose) + self.velocities[uuid] = velocity + else: + self.velocities[uuid] = np.zeros(3) # No movement if no corresponding box found + + def find_pose_in_frame(self, frame, track_uuid: str): + """ + Find the pose of the box according to track_uuid + """ + for box, pose in frame.boxes.valid_boxes(): + if box.track_uuid == track_uuid: + return pose + return None + + def calculate_velocity(self, pose1: PoseInfo, pose2: PoseInfo) -> np.ndarray: + """ + Calculate displacement between two poses (suppose time interval=1 and use it as velocity). + """ + translation1 = pose1.sensor_to_global.translation + translation2 = pose2.sensor_to_global.translation + velocity = translation2 - translation1 # Update the pose of the RenderableBox + return velocity + + def apply_velocity(self): + """ + Apply stored velocities to the boxes in the current frame. + """ + self.compute_velocities() + + current_frame = self.frames[self.current_frame_index] + last_frame = self.frames[self.current_frame_index - 1] + + for key, renderable_box in self.clickable_geometries.items(): + # for box, pose in current_frame.boxes.valid_boxes(): + uuid = renderable_box.base_box.track_uuid + if uuid in self.velocities: + last_pose = self.find_pose_in_frame(last_frame, uuid) + if last_pose: + # Update the current pose based on the last pose and velocity + new_translation = last_pose.sensor_to_global.translation + self.velocities[uuid] + # Create a new SE3 object for the new global pose + new_global_pose = SE3( + rotation_matrix=last_pose.sensor_to_global.rotation_matrix, + translation=new_translation, + ) + + renderable_box.update_from_global(new_global_pose)