Commit 84da1b3

Updated demo and frames of AV2 RGB to be consistent

kylevedder committed Mar 10, 2024 · 1 parent ed2e18b
Showing 6 changed files with 134 additions and 88 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -46,6 +46,11 @@ To run with Argoverse 2:
python scripts/demo.py --dataset Argoverse2SceneFlow --root_dir /efs/argoverse2/val/
```

<!-- Image of docs/imgs/av2_multicam.png -->

![Argoverse 2 MultiCam](docs/imgs/av2_multicam.png)


To run with Waymo Open:

```
142 changes: 85 additions & 57 deletions bucketed_scene_flow_eval/datasets/argoverse2/argoverse_raw_data.py
@@ -30,6 +30,41 @@
GROUND_HEIGHT_THRESHOLD = 0.4 # 40 centimeters


@dataclass(kw_only=True)
class CameraInfo:
    rgb_frame_paths: list[Path]
    rgb_timestamp_to_rgb_file_map: dict[int, Path]
    timestamp_to_rgb_timestamp_map: dict[int, int]
    rgb_camera_projection: CameraProjection
    rgb_camera_ego_pose: SE3

    def timestamp_to_rgb_path(self, timestamp: int) -> Path:
        assert timestamp in self.timestamp_to_rgb_timestamp_map, f"timestamp {timestamp} not found"
        rgb_timestamp = self.timestamp_to_rgb_timestamp_map[timestamp]
        assert (
            rgb_timestamp in self.rgb_timestamp_to_rgb_file_map
        ), f"rgb_timestamp {rgb_timestamp} not found"
        return self.rgb_timestamp_to_rgb_file_map[rgb_timestamp]

    def load_rgb(self, timestamp: int) -> RGBImage:
        rgb_path = self.timestamp_to_rgb_path(timestamp)
        # Read the image, keep the same color space
        raw_img = cv2.imread(str(rgb_path), cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
        # Convert from CV2 standard BGR to RGB
        raw_img = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)
        return RGBImage(raw_img)

    def load_rgb_frame(self, timestamp: int, global_pose: SE3) -> RGBFrame:
        rgb = self.load_rgb(timestamp)
        return RGBFrame(
            rgb=rgb,
            pose=PoseInfo(
                sensor_to_ego=self.rgb_camera_ego_pose.inverse(), ego_to_global=global_pose
            ),
            camera_projection=self.rgb_camera_projection,
        )
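
# Usage sketch (hypothetical timestamps and objects, not part of this commit):
# given a lidar timestamp, a CameraInfo resolves the nearest RGB capture and
# wraps it with the camera's pose and projection:
#
#   cam = camera_info_lookup["ring_front_center"]
#   path = cam.timestamp_to_rgb_path(315966070550000000)         # nearest RGB file
#   frame = cam.load_rgb_frame(315966070550000000, global_pose)  # -> RGBFrame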


class ArgoverseRawSequence(AbstractSequence):
    """
    Argoverse Raw Sequence.
@@ -47,6 +82,13 @@ def __init__(
        with_rgb: bool = False,
        POINT_CLOUD_RANGE=(-48, -48, -2.5, 48, 48, 2.5),
        sample_every: Optional[int] = None,
        camera_names=[
            "ring_front_center",
            "ring_front_right",
            "ring_front_left",
            "ring_side_right",
            "ring_side_left",
        ],
    ):
        self.log_id = log_id
        self.POINT_CLOUD_RANGE = POINT_CLOUD_RANGE
@@ -65,40 +107,19 @@ def __init__(
        (
            self.lidar_frame_paths,
            self.timestamp_to_lidar_file_map,
            self.lidar_file_timestamps,
        ) = self._load_lidar_info()

        self.with_rgb = with_rgb

        if not with_rgb:
            camera_names = []

        self.camera_info_lookup: dict[str, CameraInfo] = {
            camera_name: self._prep_camera_info(camera_name) for camera_name in camera_names
        }

        self.timestamp_list = sorted(self.lidar_file_timestamps.intersection(info_timestamps))
        assert len(self.timestamp_list) > 0, f"no timestamps found in {self.dataset_dir}"

        if sample_every is not None:
@@ -115,6 +136,31 @@ def __init__(
f"Loaded {len(self.timestamp_list)} frames from {self.dataset_dir} at timestamp {time.time():.3f}"
)

    def _prep_camera_info(self, camera_name: str) -> CameraInfo:
        (
            rgb_frame_paths,
            rgb_timestamp_to_rgb_file_map,
        ) = self._load_rgb_info(camera_name)

        rgb_camera_projection = self._load_camera_projection(camera_name)
        rgb_camera_ego_pose = self._load_camera_ego_pose(camera_name)

        # Match each lidar timestamp to the nearest RGB timestamp for this camera.
        timestamp_to_rgb_timestamp_map = {
            lidar_timestamp: min(
                rgb_timestamp_to_rgb_file_map.keys(),
                key=lambda rgb_timestamp: abs(rgb_timestamp - lidar_timestamp),
            )
            for lidar_timestamp in self.lidar_file_timestamps
        }

        return CameraInfo(
            rgb_frame_paths=rgb_frame_paths,
            rgb_timestamp_to_rgb_file_map=rgb_timestamp_to_rgb_file_map,
            timestamp_to_rgb_timestamp_map=timestamp_to_rgb_timestamp_map,
            rgb_camera_projection=rgb_camera_projection,
            rgb_camera_ego_pose=rgb_camera_ego_pose,
        )
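    # Note (illustrative, not part of this commit): the nearest-timestamp match
    # above scans every RGB timestamp for every lidar timestamp, i.e. O(N * M)
    # per camera. Sorting the RGB timestamps once and binary-searching per lidar
    # timestamp would give O((N + M) log M):
    #
    #   import bisect
    #   rgb_sorted = sorted(rgb_timestamp_to_rgb_file_map.keys())
    #   def nearest(t: int) -> int:
    #       i = bisect.bisect_left(rgb_sorted, t)
    #       candidates = rgb_sorted[max(i - 1, 0) : i + 1]
    #       return min(candidates, key=lambda r: abs(r - t))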

    def _load_lidar_info(self):
        # Load the lidar frame information.
        lidar_frame_directory = self.dataset_dir / "sensors" / "lidar"
@@ -176,9 +222,13 @@ def _load_camera_ego_pose(self, sensor_name: str) -> SE3:
        rotation = self._quat_to_mat(qw, qx, qy, qz)
        translation = np.array([tx, ty, tz])

        # fmt: off
        coordinate_transform_matrix = np.array(
            [[0, -1, 0],
             [0, 0, -1],
             [1, 0, 0]],
        )
        # fmt: on

        rotation = rotation @ coordinate_transform_matrix
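        # Sanity check (illustrative, not part of this commit): the axis-remapping
        # matrix is a proper rotation -- orthonormal with determinant +1 -- so it
        # remaps axes without scaling or net reflection:
        #
        #   M = np.array([[0, -1, 0], [0, 0, -1], [1, 0, 0]])
        #   assert np.allclose(M @ M.T, np.eye(3))
        #   assert np.isclose(np.linalg.det(M), 1.0)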

@@ -287,17 +337,6 @@ def _load_pc(self, idx) -> PointCloud:
        points = np.stack([xs, ys, zs], axis=1)
        return PointCloud(points)

    def _load_pose(self, idx) -> SE3:
        assert idx < len(self), f"idx {idx} out of range, len {len(self)} for {self.dataset_dir}"
        timestamp = self.timestamp_list[idx]
@@ -318,23 +357,15 @@ def load(self, idx: int, relative_to_idx: int) -> RawItem:
        assert idx < len(self), f"idx {idx} out of range, len {len(self)} for {self.dataset_dir}"
        timestamp = self.timestamp_list[idx]
        ego_pc = self._load_pc(idx)

        start_pose = self._load_pose(relative_to_idx)
        idx_pose = self._load_pose(idx)
        relative_pose = start_pose.inverse().compose(idx_pose)
        absolute_global_frame_pc = ego_pc.transform(idx_pose)
        is_ground_points = self.is_ground_points(absolute_global_frame_pc)
        relative_global_frame_pc_with_ground = ego_pc.transform(relative_pose)
        relative_global_frame_pc_no_ground = relative_global_frame_pc_with_ground.mask_points(
            ~is_ground_points
        )
        ego_pc_no_ground = ego_pc.mask_points(~is_ground_points)

        in_range_mask_with_ground = self.is_in_range(relative_global_frame_pc_with_ground)
        in_range_mask_no_ground = self.is_in_range(relative_global_frame_pc_no_ground)

        pc_frame = PointCloudFrame(
            full_pc=ego_pc,
@@ -343,11 +374,8 @@ def load(self, idx: int, relative_to_idx: int) -> RawItem:
        )

        rgb_frames = [
            camera.load_rgb_frame(timestamp, relative_pose)
            for _, camera in sorted(self.camera_info_lookup.items())
        ]
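        # Note (illustrative, not part of this commit): sorted(...) above makes
        # the camera order in rgb_frames deterministic (alphabetical by camera
        # name) regardless of dict insertion order, e.g.:
        #
        #   sorted(["ring_side_left", "ring_front_center"])
        #   -> ["ring_front_center", "ring_side_left"]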

        return RawItem(
1 change: 1 addition & 0 deletions bucketed_scene_flow_eval/datasets/waymoopen/dataset.py
@@ -28,6 +28,7 @@ def __init__(
        subsequence_length: int = 2,
        cache_path: Path = Path("/tmp/"),
        eval_type: str = "bucketed_epe",
        with_rgb: bool = False,
        eval_args=dict(),
    ) -> None:
        self.root_dir = Path(root_dir)
1 change: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ def __init__(
        sequence_dir: Path,
        log_subset: Optional[list[str]] = None,
        verbose: bool = False,
        with_rgb: bool = False,
    ):
        super().__init__()
        self.dataset_dir = Path(sequence_dir)
Binary file added docs/imgs/av2_multicam.png
73 changes: 42 additions & 31 deletions scripts/demo.py
@@ -35,46 +35,57 @@ def process_lidar_only(o3d_vis: O3DVisualizer, pc_frame: PointCloudFrame):
def process_lidar_rgb(
    o3d_vis: O3DVisualizer, pc_frame: PointCloudFrame, rgb_frames: list[RGBFrame]
):
    o3d_vis.add_pointcloud(pc_frame.global_pc)
    for rgb_frame in rgb_frames:
        image_plane_pc, colors = rgb_frame.camera_projection.image_to_image_plane_pc(
            rgb_frame.rgb, depth=20
        )
        image_plane_pc = image_plane_pc.transform(rgb_frame.pose.sensor_to_ego.inverse())
        o3d_vis.add_pointcloud(image_plane_pc, color=colors)
    o3d_vis.run()

    # rgb_frame = rgb_frames[0]
    # image_plane_pc, colors = rgb_frame.camera_projection.image_to_image_plane_pc(
    #     rgb_frame.rgb, depth=10
    # )

    # pc_into_cam_frame_se3 = pc_frame.pose.sensor_to_ego.compose(
    #     rgb_frame.pose.sensor_to_ego.inverse()
    # )
    # cam_frame_pc = pc_frame.full_pc.transform(pc_into_cam_frame_se3)

    # # To prevent points behind the camera from being projected into the image, we had to remove them from the pointcloud.
    # # These points have a negative X value in the camera frame.

    # cam_frame_pc = PointCloud(cam_frame_pc.points[cam_frame_pc.points[:, 0] >= 0])

    # o3d_vis.add_pointcloud(cam_frame_pc)
    # o3d_vis.add_pointcloud(image_plane_pc, color=colors)
    # o3d_vis.run()

    # projected_points = rgb_frame.camera_projection.camera_frame_to_pixels(cam_frame_pc.points)
    # projected_points = projected_points.astype(np.int32)

    # # Use distance to color points, normalized to [0, 1]. Let points more than 10m away be black.
    # colors = color_threshold_distance(cam_frame_pc.points[:, 0], max_distance=10)

    # # Mask out points that are out of bounds
    # valid_points_mask = (
    #     (projected_points[:, 0] >= 0)
    #     & (projected_points[:, 0] < rgb_frame.rgb.image.shape[1])
    #     & (projected_points[:, 1] >= 0)
    #     & (projected_points[:, 1] < rgb_frame.rgb.image.shape[0])
    # )
    # projected_points = projected_points[valid_points_mask]
    # colors = colors[valid_points_mask]

    # projected_rgb_image = rgb_frame.rgb.image
    # projected_rgb_image[projected_points[:, 1], projected_points[:, 0], :] = colors
    # plt.imshow(projected_rgb_image)
    # plt.show()
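
# Conceptual sketch (assumed pinhole semantics, not necessarily the library's
# actual implementation) of what image_to_image_plane_pc does above: unproject
# every pixel to a fixed depth along its camera ray using the intrinsics K.
#
#   us, vs = np.meshgrid(np.arange(w), np.arange(h))
#   pixels = np.stack([us, vs, np.ones_like(us)], axis=-1).reshape(-1, 3)
#   rays = pixels @ np.linalg.inv(K).T  # one ray per pixel, z == 1
#   points = rays * depth               # points on the z == depth plane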


def process_entry(query: QuerySceneSequence, gt: GroundTruthPointFlow):
