Skip to content

Commit

Permalink
Improved data handling; added velocity calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
lisiyi777 committed Aug 7, 2024
1 parent d8798c6 commit 2c08eee
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 49 deletions.
86 changes: 44 additions & 42 deletions flow_lab/flow_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,6 @@
from bucketed_scene_flow_eval.utils.glfw_key_ids import *


def setup_save_path(root_dir: Path, sequence_id: str, save_dir, preprocess) -> Path:
"""
Sets up the path to the directory where processed data should be saved.
"""
if save_dir:
# Check if the provided save_dir exists
if not save_dir.exists():
raise ValueError(f"The provided save_dir '{save_dir}' does not exist.")
sequence_save_dir = save_dir / sequence_id
sequence_save_dir.mkdir(exist_ok=True)
else:
if preprocess:
# Get the parent directory of root_dir and create a new directory called root_dir_processed
parent_dir = root_dir.parent
processed_dir = parent_dir / f"{root_dir.name}_processed"
processed_dir.mkdir(exist_ok=True)
sequence_save_dir = processed_dir / sequence_id
sequence_save_dir.mkdir(exist_ok=True)
else:
sequence_save_dir = root_dir / sequence_id
sequence_save_dir.mkdir(exist_ok=True)

return sequence_save_dir


def parse_arguments():
# Define the default path for the lookup table inside the flow_lab folder
flow_lab_dir = Path(__file__).resolve().parent
Expand Down Expand Up @@ -69,12 +44,12 @@ def parse_arguments():
default=default_lookup_table_path,
help="Path to JSON lookup table for sequence lengths.",
)
parser.add_argument(
"--save_dir",
type=Path,
required=False,
help="Directory where processed data will be saved. ",
)
# parser.add_argument(
# "--save_dir",
# type=Path,
# required=False,
# help="Directory where processed data will be saved. ",
# )
parser.add_argument(
"--preprocess",
action="store_true",
Expand All @@ -88,7 +63,6 @@ def parse_arguments():
)

args = parser.parse_args()
args.save_dir = setup_save_path(args.root_dir, args.sequence_id, args.save_dir, args.preprocess)

return args

Expand Down Expand Up @@ -174,6 +148,7 @@ def setup_visualizer(state_manager, annotation_saver, frames):
vis.register_key_action_callback(
GLFW_KEY_SPACE, lambda vis, action, mods: state_manager.toggle_box(vis, action, mods)
)
vis.register_key_callback(ord("V"), state_manager.toggle_propagate_with_velocity)

vis.create_window()
# render_option = vis.get_render_option()
Expand All @@ -189,6 +164,28 @@ def load_sequence_length(sequence_id: str, lookup_table: Path) -> int:
return data.get(sequence_id, 0) # Default to 0 if the sequence_id is not found


def setup_save_path(root_dir: Path, sequence_id: str) -> Path:
"""
Set up the path to the folder where processed annotations should be saved.
"""
parent_dir = root_dir.parent
processed_dir = parent_dir / f"{root_dir.name}_processed"
sequence_save_dir = processed_dir / "val" / sequence_id

return sequence_save_dir


def setup_load_path(root_dir: Path, sequence_id: str) -> Path:
"""
Set up the path to load data when preprocess is required.
"""
parent_dir = root_dir.parent
processed_dir = parent_dir / f"{root_dir.name}_processed"
sequence_save_dir = processed_dir / "val"

return sequence_save_dir


def main():
args = parse_arguments()

Expand All @@ -197,22 +194,27 @@ def main():
if sequence_length == 0:
raise ValueError(f"Sequence ID {args.sequence_id} not found in lookup table.")

# load frames
frames = load_box_frames(args.root_dir, args.dataset_name, sequence_length, args.sequence_id)

# this is used to save edited data
annotation_saver = AnnotationSaver(args.save_dir)
output_path = setup_save_path(args.root_dir, args.sequence_id)
annotation_saver = AnnotationSaver(output_path)

if args.preprocess:
# Check if preprocessing is required and load frames
annotation_file_path = output_path / "annotations.feather"
if args.preprocess or not annotation_file_path.exists():
print("Preprocessing data...")
input_path = args.root_dir / "val"
frames = load_box_frames(input_path, args.dataset_name, sequence_length, args.sequence_id)
frames = preprocess_box_frames(frames)
annotation_saver.save(frames)
else:
# declare the view state manager and display the window
state_manager = ViewStateManager(frames, annotation_saver, args.rolling_window_size)
vis = setup_visualizer(state_manager, annotation_saver, frames)
state_manager.render_pc_and_boxes(vis, reset_bounding_box=True)
vis.run()
input_path = setup_load_path(args.root_dir, args.sequence_id)
frames = load_box_frames(input_path, args.dataset_name, sequence_length, args.sequence_id)

# declare the view state manager and display the window
state_manager = ViewStateManager(frames, annotation_saver, args.rolling_window_size)
vis = setup_visualizer(state_manager, annotation_saver, frames)
state_manager.render_pc_and_boxes(vis, reset_bounding_box=True)
vis.run()


if __name__ == "__main__":
Expand Down
103 changes: 96 additions & 7 deletions flow_lab/vis_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,12 @@ def __init__(
self.frames = frames
self.rolling_window_size = rolling_window_size
self.trajectory_geometries: list[RenderableBox] = []
# the two blow is used for zoom
self.is_zoomed = False
self.original_view = None
# Used for velocity
self.propagate_with_velocity = False
self.velocities = {}
# Used for toggle box
self.current_box_index = -1

Expand Down Expand Up @@ -243,13 +247,16 @@ def roll_counterclockwise_press(self, vis):
print("roll_counterclockwise_press")

def forward_frame_press(self, vis):
self.annotation_saver.save(self.frames)
self.current_frame_index = self.current_frame_index + 1
self.current_frame_index = min(len(self.frames) - 1, self.current_frame_index)
self.render_pc_and_boxes(vis)
# self.annotation_saver.save(self.frames)
if self.current_frame_index < len(self.frames) - 1:
self.current_frame_index += 1
self.render_pc_and_boxes(vis)
if self.propagate_with_velocity:
self.apply_velocity()
self.render_pc_and_boxes(vis)

def backward_frame_press(self, vis):
self.annotation_saver.save(self.frames)
# self.annotation_saver.save(self.frames)
self.current_frame_index = self.current_frame_index - 1
self.current_frame_index = max(0, self.current_frame_index)
self.render_pc_and_boxes(vis)
Expand Down Expand Up @@ -301,7 +308,7 @@ def on_mouse_scroll(self, vis, x, y):

def on_mouse_button(self, vis, button, action, mods):
buttons = ["left", "right", "middle"]
actions = ["up", "down"]
actions = ["up", "down", "drag"]
mods_name = ["shift", "ctrl", "alt", "cmd"]

button = buttons[button]
Expand Down Expand Up @@ -338,7 +345,7 @@ def select_mesh(self, vis, mesh_id: str):
# Use the oriented bounding box rotation as the rotation of the axes
self.selection_axes.rotate(rotation_matrix)
vis.add_geometry(self.selection_axes, reset_bounding_box=False)
self.render_selected_mesh_trajectory(vis)
# self.render_selected_mesh_trajectory(vis)

def deselect_mesh(self, vis):
self.selected_mesh_id = None
Expand Down Expand Up @@ -463,6 +470,7 @@ def render_pc_and_boxes(self, vis, reset_bounding_box: bool = False):
start_index, end_index = self.rolling_window_range()

vis.clear_geometries()
self.clickable_geometries = {}
# Loop over the frames and display pointclouds
for i in range(start_index, end_index):
frame = self.frames[i]
Expand Down Expand Up @@ -510,6 +518,9 @@ def get_annotations(self) -> list[dict]:
return annotations

def zoom_press(self, vis, action, mods):
"""
Callback function for zoom to box.
"""
actions = ["release", "press", "repeat"]
mods_name = ["shift", "ctrl", "alt", "cmd"]
action = actions[action]
Expand Down Expand Up @@ -539,6 +550,9 @@ def zoom_to_box(self, vis):
ctr.set_zoom(0.1)

def toggle_box(self, vis, action, mods):
"""
Toggle the selected box.
"""
actions = ["release", "press", "repeat"]
mods_name = ["shift", "ctrl", "alt", "cmd"]
action = actions[action]
Expand Down Expand Up @@ -569,3 +583,78 @@ def toggle_box(self, vis, action, mods):
self.zoom_to_box(vis)

# print(f"Toggled to box {new_mesh_id}")

def toggle_propagate_with_velocity(self, vis):
"""
Toggle the propagate_with_velocity feature on and off.
"""
self.propagate_with_velocity = not self.propagate_with_velocity
print(f"Propagate with velocity: {'On' if self.propagate_with_velocity else 'Off'}")
if self.propagate_with_velocity:
self.compute_velocities()

def compute_velocities(self):
"""
Compute velocities for all the meshes.
"""
if self.current_frame_index < 2:
return # Need at least two frames to calculate velocity
self.velocities.clear()

current_frame = self.frames[self.current_frame_index]
prev_frame = self.frames[self.current_frame_index - 1]
prev_prev_frame = self.frames[self.current_frame_index - 2]

for box, pose in current_frame.boxes.valid_boxes():
uuid = box.track_uuid
prev_pose = self.find_pose_in_frame(prev_frame, uuid)
prev_prev_pose = self.find_pose_in_frame(prev_prev_frame, uuid)

if prev_pose and prev_prev_pose:
velocity = self.calculate_velocity(prev_prev_pose, prev_pose)
self.velocities[uuid] = velocity
else:
self.velocities[uuid] = np.zeros(3) # No movement if no corresponding box found

def find_pose_in_frame(self, frame, track_uuid: str):
"""
Find the pose of the box according to track_uuid
"""
for box, pose in frame.boxes.valid_boxes():
if box.track_uuid == track_uuid:
return pose
return None

def calculate_velocity(self, pose1: PoseInfo, pose2: PoseInfo) -> np.ndarray:
"""
Calculate displacement between two poses (suppose time interval=1 and use it as velocity).
"""
translation1 = pose1.sensor_to_global.translation
translation2 = pose2.sensor_to_global.translation
velocity = translation2 - translation1 # Update the pose of the RenderableBox
return velocity

def apply_velocity(self):
"""
Apply stored velocities to the boxes in the current frame.
"""
self.compute_velocities()

current_frame = self.frames[self.current_frame_index]
last_frame = self.frames[self.current_frame_index - 1]

for key, renderable_box in self.clickable_geometries.items():
# for box, pose in current_frame.boxes.valid_boxes():
uuid = renderable_box.base_box.track_uuid
if uuid in self.velocities:
last_pose = self.find_pose_in_frame(last_frame, uuid)
if last_pose:
# Update the current pose based on the last pose and velocity
new_translation = last_pose.sensor_to_global.translation + self.velocities[uuid]
# Create a new SE3 object for the new global pose
new_global_pose = SE3(
rotation_matrix=last_pose.sensor_to_global.rotation_matrix,
translation=new_translation,
)

renderable_box.update_from_global(new_global_pose)

0 comments on commit 2c08eee

Please sign in to comment.