Commit 4552099: Updated docs with render results

kylevedder committed Mar 13, 2024
1 parent 6c260bc commit 4552099
Showing 10 changed files with 443 additions and 110 deletions.
README.md: 12 changes (10 additions & 2 deletions)
@@ -53,10 +53,18 @@ python scripts/demo.py --dataset Argoverse2SceneFlow --root_dir /efs/argoverse2/
 To render RGB frames with lidar imposed on top, run:
 
 ```
-python scripts/demo.py --dataset Argoverse2SceneFlow --root_dir /efs/argoverse2/val/ --mode rgb --save_dir path/to/save_dir
+python scripts/demo_rgb.py --dataset Argoverse2SceneFlow --mode project_lidar --save_dir /efs/av2_camera_render/ --reduction_factor 16 --root_dir /efs/argoverse2/val
 ```
 
-![Argoverse 2 MultiCam Video](docs/imgs/av2_multicam_video.gif)
+![Argoverse 2 MultiCam Video](docs/imgs/av2_lidar.gif)
+
+To render the flow field of an Argoverse 2 sequence, run:
+
+```
+python scripts/demo_rgb.py --dataset Argoverse2SceneFlow --mode project_lidar --save_dir /efs/av2_camera_render/ --reduction_factor 16 --root_dir /efs/argoverse2/val --flow_dir <path to method flow output>
+```
+
+![Argoverse 2 MultiCam Video](docs/imgs/av2_gt_flow.gif)
 
 ### Waymo Open:
 
bucketed_scene_flow_eval/datastructures/pointcloud.py: 2 changes (1 addition & 1 deletion)
@@ -133,7 +133,7 @@ def flow_masked(self, flow: np.ndarray, mask: np.ndarray) -> "PointCloud":
         # check that flow has the same number of entries as the boolean mask.
         assert (
             flow.shape[0] == mask.sum()
-        ), f"flow must have same number of entries as the number of True values in the mask, got {flow.shape[0]} and {mask.sum()}"
+        ), f"flow must have same number of entries as the number of True values in the mask, got flow shape of {flow.shape[0]} and mask sum of {mask.sum()}"
         flow = flow.astype(np.float32)
         updated_points = self.points.copy()
         updated_points[mask] = self.points[mask] + flow
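For reference, a minimal sketch of how `flow_masked` is called, using toy data (the values are illustrative, not from the repository): the flow array must carry exactly one row per `True` entry in the mask, which is what the sharpened assertion message now reports.

```
import numpy as np

from bucketed_scene_flow_eval.datastructures import PointCloud

# Toy data: a 4-point cloud where only 2 points have flow vectors.
pc = PointCloud(np.zeros((4, 3), dtype=np.float32))
mask = np.array([True, False, True, False])
flow = np.ones((2, 3), dtype=np.float32)  # one row per True entry in mask

# Moves the two masked points by (1, 1, 1); a mismatched row count
# would trigger the assertion above with the clearer message.
flowed = pc.flow_masked(flow, mask)
```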
bucketed_scene_flow_eval/datastructures/rgb_image.py: 18 changes (18 additions & 0 deletions)
@@ -25,6 +25,24 @@ def __init__(self, image: np.ndarray):

         self.image = image.astype(np.float32)
 
+    @staticmethod
+    def white_image(shape: tuple[int, int]) -> "RGBImage":
+        assert len(shape) == 2, f"shape must be a 2-tuple, got {shape}"
+        return RGBImage(np.ones(shape + (3,), dtype=np.float32))
+
+    @staticmethod
+    def white_image_like(image: "RGBImage") -> "RGBImage":
+        return RGBImage.white_image(image.shape[:2])
+
+    @staticmethod
+    def black_image(shape: tuple[int, int]) -> "RGBImage":
+        assert len(shape) == 2, f"shape must be a 2-tuple, got {shape}"
+        return RGBImage(np.zeros(shape + (3,), dtype=np.float32))
+
+    @staticmethod
+    def black_image_like(image: "RGBImage") -> "RGBImage":
+        return RGBImage.black_image(image.shape[:2])
+
     def __repr__(self) -> str:
         return f"RGBImage(shape={self.image.shape}, dtype={self.image.dtype})"
 
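A minimal usage sketch of the new constructors, assuming `RGBImage` is exported from `bucketed_scene_flow_eval.datastructures` as the neighboring types are:

```
import numpy as np

from bucketed_scene_flow_eval.datastructures import RGBImage

# Solid white and solid black canvases from a bare (height, width) shape;
# the helpers append the 3-channel dimension themselves.
white = RGBImage.white_image((480, 640))
black = RGBImage.black_image((480, 640))

assert white.image.shape == (480, 640, 3)
assert np.allclose(black.image, 0.0)
```

The `_like` variants mirror `np.ones_like`/`np.zeros_like`, taking their target shape from an existing image instead of an explicit tuple.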
bucketed_scene_flow_eval/datastructures/scene_sequence.py: 54 changes (54 additions & 0 deletions)
@@ -24,6 +24,14 @@ class PoseInfo:
     sensor_to_ego: SE3
     ego_to_global: SE3
 
+    def __eq__(self, __value: object) -> bool:
+        if not isinstance(__value, PoseInfo):
+            return False
+        return (
+            self.sensor_to_ego == __value.sensor_to_ego
+            and self.ego_to_global == __value.ego_to_global
+        )
+
     def __repr__(self) -> str:
         return f"PoseInfo(sensor_to_ego={self.sensor_to_ego}, ego_to_global={self.ego_to_global})"
 
@@ -52,6 +60,42 @@ def global_pc(self) -> PointCloud:
     def global_pose(self) -> SE3:
         return self.pose.ego_to_global @ self.pose.sensor_to_ego
 
+    def add_global_flow(self, flow: NDArray, valid_flow_mask: NDArray) -> "PointCloudFrame":
+        assert flow.ndim == 2, f"flow must be a 2D array, got {flow.ndim}"
+        assert (
+            valid_flow_mask.ndim == 1
+        ), f"valid_flow_mask must be a 1D array, got {valid_flow_mask.ndim}"
+        assert (
+            valid_flow_mask.dtype == bool
+        ), f"valid_flow_mask must be boolean, got {valid_flow_mask.dtype}"
+
+        assert len(flow) == len(valid_flow_mask), (
+            f"flow and valid_flow_mask must have the same length, got {len(flow)} and "
+            f"{len(valid_flow_mask)}"
+        )
+
+        assert len(flow) == len(
+            self.full_pc
+        ), f"flow length {len(flow)} must match point cloud length {len(self.full_pc)}"
+
+        assert self.mask.shape == valid_flow_mask.shape, (
+            f"mask and valid_flow_mask must have the same shape, got {self.mask.shape} and "
+            f"{valid_flow_mask.shape}"
+        )
+
+        # Convert to global pc, add flow, then convert back to ego frame
+        flowed_ego_pc = self.full_global_pc.flow_masked(
+            flow[valid_flow_mask], valid_flow_mask
+        ).transform(self.global_pose.inverse())
+        # Only include points that are valid and in the mask
+        joined_mask = self.mask & valid_flow_mask
+
+        return PointCloudFrame(
+            full_pc=flowed_ego_pc,
+            pose=self.pose,
+            mask=joined_mask,
+        )
+
 
 @dataclass
 class RGBFrame:
@@ -298,6 +342,16 @@ def __init__(
     def valid_particle_ids(self) -> NDArray:
         return np.arange(self.num_entries)[self.is_valid_flow]
 
+    def get_flow(
+        self, src_timestamp: Timestamp, target_timestamp: Timestamp
+    ) -> tuple[NDArray, NDArray]:
+        src_idx = np.where(self.trajectory_timestamps == src_timestamp)[0][0]
+        target_idx = np.where(self.trajectory_timestamps == target_timestamp)[0][0]
+        return (
+            self.world_points[:, target_idx] - self.world_points[:, src_idx],
+            self.is_valid_flow,
+        )
+
     def __len__(self) -> int:
         return self.is_valid_flow.sum()
 
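Taken together, the new `get_flow` and `add_global_flow` methods let a caller pull the ground-truth flow field between two timestamps and apply it to a point cloud frame. A rough sketch, assuming the ground truth covers the frame's full point cloud (the helper name `flow_frame_between` is hypothetical):

```
from bucketed_scene_flow_eval.datastructures import (
    GroundTruthPointFlow,
    PointCloudFrame,
)


def flow_frame_between(
    pc_frame: PointCloudFrame, gt: GroundTruthPointFlow, t0, t1
) -> PointCloudFrame:
    # Global-frame flow vectors between the two timestamps, plus the
    # validity mask over all trajectories.
    flow, valid_mask = gt.get_flow(t0, t1)
    # The flow is applied in the global frame and the result is returned in
    # the ego frame, with the frame mask ANDed against the validity mask.
    return pc_frame.add_global_flow(flow, valid_mask)
```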
bucketed_scene_flow_eval/datastructures/se3.py: 7 changes (7 additions & 0 deletions)
@@ -36,6 +36,13 @@ def from_rot_w_x_y_z_translation_x_y_z(rw, rx, ry, rz, tx, ty, tz) -> "SE3":
         translation = np.array([tx, ty, tz])
         return SE3(rotation_matrix, translation)
 
+    def __eq__(self, __value: object) -> bool:
+        if not isinstance(__value, SE3):
+            return False
+        return np.allclose(self.rotation_matrix, __value.rotation_matrix) and np.allclose(
+            self.translation, __value.translation
+        )
+
     def translate(self, translation: np.ndarray) -> "SE3":
         """Return a new SE3 instance with the given translation applied."""
         if isinstance(translation, list):
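Equality here is tolerance-based via `np.allclose`, so transforms that differ only by floating point noise still compare equal; the new `PoseInfo.__eq__` above builds on this. A small sketch, assuming `SE3` is exported alongside the other datastructures:

```
import numpy as np

from bucketed_scene_flow_eval.datastructures import SE3

a = SE3(np.eye(3), np.zeros(3))
b = SE3(np.eye(3), np.zeros(3) + 1e-12)  # sub-tolerance numerical noise

# allclose-based comparison: tiny drift from chained transforms
# does not break equality checks.
assert a == b
```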
Binary file added docs/imgs/av2_gt_flow.gif
Binary file removed docs/imgs/av2_multicam_video.mp4
Binary file not shown.
scripts/demo.py: 109 changes (2 additions & 107 deletions)
@@ -6,66 +6,7 @@
 from matplotlib import pyplot as plt
 
 from bucketed_scene_flow_eval.datasets import construct_dataset
-from bucketed_scene_flow_eval.datastructures import (
-    GroundTruthPointFlow,
-    O3DVisualizer,
-    PointCloud,
-    PointCloudFrame,
-    QuerySceneSequence,
-    RGBFrame,
-    RGBFrameLookup,
-    RGBImage,
-)
-
-
-def color_by_distance(distances: np.ndarray, max_distance: float = 10.0, cmap: str = "viridis"):
-    # Use distance to color points, normalized to [0, 1].
-    colors = distances.copy()
-
-    # Normalize to [0, 1]
-    colors = colors / max_distance
-    colors[colors > 1] = 1.0
-
-    colormap = plt.get_cmap(cmap)
-    colors = colormap(colors)[:, :3]
-    return colors
-
-
-def process_lidar_only(o3d_vis: O3DVisualizer, pc_frame: PointCloudFrame):
-    print(f"Adding Lidar pointcloud with {len(pc_frame.global_pc)} points")
-    o3d_vis.add_pointcloud(pc_frame.global_pc)
-    o3d_vis.run()
-
-
-def project_lidar_into_rgb(
-    pc_frame: PointCloudFrame, rgb_frame: RGBFrame, reduction_factor: int = 4
-) -> RGBImage:
-    pc_into_cam_frame_se3 = pc_frame.pose.sensor_to_ego.inverse().compose(
-        rgb_frame.pose.sensor_to_ego
-    )
-    cam_frame_pc = pc_frame.full_pc.transform(pc_into_cam_frame_se3)
-    cam_frame_pc = PointCloud(cam_frame_pc.points[cam_frame_pc.points[:, 0] >= 0])
-
-    projected_points = rgb_frame.camera_projection.camera_frame_to_pixels(cam_frame_pc.points)
-    projected_points = projected_points.astype(np.int32)
-
-    # Use distance to color points, normalized to [0, 1].
-    colors = color_by_distance(cam_frame_pc.points[:, 0], max_distance=30)
-    valid_points_mask = (
-        (projected_points[:, 0] >= 0)
-        & (projected_points[:, 0] < rgb_frame.rgb.image.shape[1])
-        & (projected_points[:, 1] >= 0)
-        & (projected_points[:, 1] < rgb_frame.rgb.image.shape[0])
-    )
-    projected_points = projected_points[valid_points_mask]
-    colors = colors[valid_points_mask]
-
-    scaled_rgb = rgb_frame.rgb.rescale(reduction_factor)
-    scaled_projected_points = projected_points // reduction_factor
-
-    projected_rgb_image = scaled_rgb.image
-    projected_rgb_image[scaled_projected_points[:, 1], scaled_projected_points[:, 0], :] = colors
-    return RGBImage(projected_rgb_image)
+from bucketed_scene_flow_eval.datastructures import O3DVisualizer, QuerySceneSequence


 def visualize_lidar_3d(query: QuerySceneSequence):
Expand All @@ -86,55 +27,12 @@ def visualize_lidar_3d(query: QuerySceneSequence):
     del o3d_vis


-def visualize_rgb(frame_idx: int, query: QuerySceneSequence, save_dir: Optional[Path] = None):
-    # The query specifies the raw scene and query points at a particular timestamp
-    # These query points can be thought of as the specification of the valid points for
-    # scene flow in the pointcloud at `t` for prediction to timestamp `t+1`
-
-    scene_timestamp = query.query_particles.query_init_timestamp
-
-    # The scene contains RGB image and pointcloud data for each timestamp.
-    # These are stored as "frames" with pose and intrinsics information.
-    # This enables the raw percepts to be projected into desired coordinate frames across time.
-    rgb_frames = query.scene_sequence[scene_timestamp].rgb_frames
-    pc_frame = query.scene_sequence[scene_timestamp].pc_frame
-
-    items = rgb_frames.items()
-    for plot_idx, (name, rgb_frame) in enumerate(items):
-        plt.subplot(1, len(items), plot_idx + 1)
-        plt.imshow(project_lidar_into_rgb(pc_frame, rgb_frame).image)
-        # Disable axis ticks
-        plt.xticks([])
-        plt.yticks([])
-        # Set padding between subplots to 0
-        plt.tight_layout(pad=0)
-        # Get rid of black border
-        # plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
-        # Get rid of white space
-        plt.margins(0)
-        ax = plt.gca()
-        ax.spines["top"].set_visible(False)
-        ax.spines["right"].set_visible(False)
-        ax.spines["bottom"].set_visible(False)
-        ax.spines["left"].set_visible(False)
-
-    if save_dir is None:
-        plt.show()
-    else:
-        save_location = save_dir / f"{query.scene_sequence.log_id}" / f"{frame_idx:010d}.png"
-        save_location.parent.mkdir(parents=True, exist_ok=True)
-        plt.savefig(save_location, bbox_inches="tight", pad_inches=0, dpi=200)
-        plt.clf()


 if __name__ == "__main__":
     # Take arguments to specify dataset and root directory
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset", type=str, default="Argoverse2SceneFlow")
     parser.add_argument("--root_dir", type=str, default="/efs/argoverse2/val")
-    parser.add_argument("--skip_rgb", action="store_true")
-    parser.add_argument("--mode", choices=["lidar", "rgb"], default="lidar")
-    parser.add_argument("--save_dir", type=Path, default=None)
     args = parser.parse_args()

     dataset = construct_dataset(
@@ -144,7 +42,4 @@ def visualize_rgb(frame_idx: int, query: QuerySceneSequence, save_dir: Optional[
     print("Dataset contains", len(dataset), "samples")
 
     for idx, (query, gt) in enumerate(dataset):
-        if args.mode == "rgb":
-            visualize_rgb(idx, query, args.save_dir)
-        else:
-            visualize_lidar_3d(query)
+        visualize_lidar_3d(query)
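With this change `scripts/demo.py` only drives the 3D lidar visualizer; the RGB projection and flow rendering paths now live in `scripts/demo_rgb.py`, invoked as in the README snippets above. The simplified script runs as:

```
python scripts/demo.py --dataset Argoverse2SceneFlow --root_dir /efs/argoverse2/val
```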
