From 0ba74c23ee1a5fa3dcd561546428c88e56da0f05 Mon Sep 17 00:00:00 2001 From: Kyle Vedder Date: Thu, 22 Aug 2024 15:26:14 -0400 Subject: [PATCH] Added fix for flow loading --- .../waymoopen/waymo_supervised_flow.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py b/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py index a1d8825..0231036 100644 --- a/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py +++ b/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py @@ -34,11 +34,11 @@ def __init__(self, sequence_folder: Path, flow_folder: Path | None, verbose: boo self.sequence_folder = Path(sequence_folder) self.sequence_files = sorted(self.sequence_folder.glob("*.pkl")) if flow_folder is not None: - self.flow_folder = Path(flow_folder) - self.flow_files = sorted(self.flow_folder.glob("*.feather")) - assert len(self.sequence_files) == len(self.flow_files), ( + self.flow_folder: Path | None = Path(flow_folder) + self.flow_files: list[Path] | None = sorted(self.flow_folder.glob("*.feather")) + assert len(self.sequence_files) - 1 == len(self.flow_files), ( f"number of frames in {self.sequence_folder} does not match number of frames in " - f"{self.flow_folder}" + f"{self.flow_folder}; {len(self.sequence_files)} vs {len(self.flow_files)}" ) else: self.flow_folder = None @@ -60,9 +60,18 @@ def _load_idx(self, idx: int): pkl = load_pickle(pickle_path, verbose=False) pc = PointCloud(pkl["car_frame_pc"]) flow = pkl["flow"] - if self.flow_folder is not None: + if self.flow_files is not None: flow_path = self.flow_files[idx] - flow = load_feather(flow_path).to_numpy() + + flow_df = load_feather(flow_path) + flow = flow_df[["flow_tx_m", "flow_ty_m", "flow_tz_m"]].values + marked_is_valid = flow_df["is_valid"].values + flow = flow.astype(np.float32) + # Zero out invalid flow values + flow[~marked_is_valid] = 0 + assert isinstance(flow, np.ndarray), f"flow is not a numpy array: {type(flow)}" + assert flow.shape[1] == 3, f"flow has shape {flow.shape} instead of (N, 3)" + assert flow.dtype == np.float32, f"flow has dtype {flow.dtype} instead of float32" assert len(flow) == len( pc ), f"number of points in flow {len(flow)} does not match number of points in pc {len(pc)}"