From 0e4ab40b8077445b189c1af7aa98f50c35928677 Mon Sep 17 00:00:00 2001
From: Kyle Vedder <kyle.c.vedder@gmail.com>
Date: Thu, 14 Mar 2024 15:42:56 -0400
Subject: [PATCH] Bunch of stuff is still broken, but AV2 Demo 3D works for RGB
 + Lidar

---
 bucketed_scene_flow_eval/datasets/__init__.py |   8 +-
 .../datasets/argoverse2/argoverse_raw_data.py |  36 +-
 .../argoverse2/argoverse_scene_flow.py        | 117 +++--
 .../datasets/argoverse2/dataset.py            | 165 +------
 .../datasets/nuscenes/nuscenes_loader.py      |   2 +-
 .../scene_representations.py                  |  32 --
 .../datasets/waymoopen/dataset.py             |  18 +-
 .../waymoopen/waymo_supervised_flow.py        |   2 +-
 .../datastructures/__init__.py                |  65 ++-
 .../datastructures/camera_projection.py       |   9 +
 .../datastructures/dataclasses.py             | 231 ++++++++++
 .../datastructures/o3d_visualizer.py          |  36 +-
 .../datastructures/pointcloud.py              |   2 +-
 .../datastructures/scene_sequence.py          | 427 ------------------
 .../eval/base_per_frame_sceneflow_eval.py     | 214 +++------
 bucketed_scene_flow_eval/eval/bucketed_epe.py |  13 +-
 bucketed_scene_flow_eval/eval/eval.py         |  12 +-
 .../__init__.py                               |   6 +-
 .../interfaces/abstract_dataset.py            |  14 +
 .../abstract_sequence_loader.py               |   9 +-
 scripts/demo.py                               |  45 --
 scripts/demo_3d.py                            |  54 +++
 tests/argoverse2/av2_tests.py                 |  16 +-
 tests/eval/bucketed_epe.py                    |  37 +-
 tests/integration_tests.py                    | 125 +----
 25 files changed, 625 insertions(+), 1070 deletions(-)
 delete mode 100644 bucketed_scene_flow_eval/datasets/shared_datastructures/scene_representations.py
 create mode 100644 bucketed_scene_flow_eval/datastructures/dataclasses.py
 delete mode 100644 bucketed_scene_flow_eval/datastructures/scene_sequence.py
 rename bucketed_scene_flow_eval/{datasets/shared_datastructures => interfaces}/__init__.py (65%)
 create mode 100644 bucketed_scene_flow_eval/interfaces/abstract_dataset.py
 rename bucketed_scene_flow_eval/{datasets/shared_datastructures => interfaces}/abstract_sequence_loader.py (83%)
 delete mode 100644 scripts/demo.py
 create mode 100644 scripts/demo_3d.py

diff --git a/bucketed_scene_flow_eval/datasets/__init__.py b/bucketed_scene_flow_eval/datasets/__init__.py
index 07beb04..6bc0483 100644
--- a/bucketed_scene_flow_eval/datasets/__init__.py
+++ b/bucketed_scene_flow_eval/datasets/__init__.py
@@ -1,11 +1,13 @@
 from bucketed_scene_flow_eval.datasets.argoverse2 import Argoverse2SceneFlow
-from bucketed_scene_flow_eval.datasets.waymoopen import WaymoOpenSceneFlow
+from bucketed_scene_flow_eval.interfaces import AbstractDataset
 
-importable_classes = [Argoverse2SceneFlow, WaymoOpenSceneFlow]
+# from bucketed_scene_flow_eval.datasets.waymoopen import WaymoOpenSceneFlow
+
+importable_classes = [Argoverse2SceneFlow]  # , WaymoOpenSceneFlow]
 name_to_class_lookup = {cls.__name__.lower(): cls for cls in importable_classes}
 
 
-def construct_dataset(name: str, args: dict):
+def construct_dataset(name: str, args: dict) -> AbstractDataset:
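+    """Look up a dataset class by its lowercased name and construct it from `args`.
+
+    A hypothetical call (the exact kwargs depend on the dataset's constructor):
+        construct_dataset("argoverse2sceneflow", {"root_dir": "/data/av2/val"})
+    """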
     name = name.lower()
     if name not in name_to_class_lookup:
         raise ValueError(f"Unknown dataset name: {name}")
diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_raw_data.py b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_raw_data.py
index e1e6d02..d1c1397 100644
--- a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_raw_data.py
+++ b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_raw_data.py
@@ -8,12 +8,6 @@
 import pandas as pd
 from scipy.spatial.transform import Rotation
 
-from bucketed_scene_flow_eval.datasets.shared_datastructures import (
-    AbstractSequence,
-    AbstractSequenceLoader,
-    CachedSequenceLoader,
-    RawItem,
-)
 from bucketed_scene_flow_eval.datastructures import (
     SE2,
     SE3,
@@ -25,7 +19,10 @@
     RGBFrame,
     RGBFrameLookup,
     RGBImage,
+    TimeSyncedAVLidarData,
+    TimeSyncedRawItem,
 )
+from bucketed_scene_flow_eval.interfaces import AbstractSequence, CachedSequenceLoader
 from bucketed_scene_flow_eval.utils import load_json
 
 GROUND_HEIGHT_THRESHOLD = 0.4  # 40 centimeters
@@ -356,7 +353,9 @@ def _load_pose(self, idx) -> SE3:
         )
         return se3
 
-    def load(self, idx: int, relative_to_idx: int) -> RawItem:
+    def load(
+        self, idx: int, relative_to_idx: int
+    ) -> tuple[TimeSyncedRawItem, TimeSyncedAVLidarData]:
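+        # The auxiliary TimeSyncedAVLidarData (ground / in-range masks) rides
+        # alongside the raw item so downstream code can filter points later.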
         assert idx < len(self), f"idx {idx} out of range, len {len(self)} for {self.dataset_dir}"
         timestamp = self.timestamp_list[idx]
         ego_pc = self._load_pc(idx)
@@ -384,17 +383,22 @@ def load(self, idx: int, relative_to_idx: int) -> RawItem:
             self.camera_names,
         )
 
-        return RawItem(
-            pc=pc_frame,
-            rgbs=rgb_frames,
-            is_ground_points=is_ground_points,
-            in_range_mask=in_range_mask_with_ground,
-            log_id=self.log_id,
-            log_idx=idx,
-            log_timestamp=timestamp,
+        return (
+            TimeSyncedRawItem(
+                pc=pc_frame,
+                rgbs=rgb_frames,
+                log_id=self.log_id,
+                log_idx=idx,
+                log_timestamp=timestamp,
+            ),
+            TimeSyncedAVLidarData(
+                is_ground_points=is_ground_points, in_range_mask=in_range_mask_with_ground
+            ),
         )
 
-    def load_frame_list(self, relative_to_idx: Optional[int]) -> list[RawItem]:
+    def load_frame_list(
+        self, relative_to_idx: Optional[int]
+    ) -> list[tuple[TimeSyncedRawItem, TimeSyncedAVLidarData]]:
         return [
             self.load(idx, relative_to_idx if relative_to_idx is not None else idx)
             for idx in range(len(self))
diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py
index f8b9193..bf3f7e7 100644
--- a/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py
+++ b/bucketed_scene_flow_eval/datasets/argoverse2/argoverse_scene_flow.py
@@ -5,14 +5,20 @@
 
 import numpy as np
 
-from bucketed_scene_flow_eval.datasets.shared_datastructures import (
-    AbstractSequence,
-    AbstractSequenceLoader,
-    CachedSequenceLoader,
-    RawItem,
-    SceneFlowItem,
+from bucketed_scene_flow_eval.datastructures import (
+    EgoLidarFlow,
+    MaskArray,
+    PointCloud,
+    PointCloudFrame,
+    SemanticClassId,
+    SemanticClassIdArray,
+    SupervisedPointCloudFrame,
+    TimeSyncedAVLidarData,
+    TimeSyncedRawItem,
+    TimeSyncedSceneFlowItem,
+    VectorArray,
 )
-from bucketed_scene_flow_eval.datastructures import PointCloud, PointCloudFrame
+from bucketed_scene_flow_eval.interfaces import CachedSequenceLoader
 from bucketed_scene_flow_eval.utils.loaders import load_feather
 
 from . import ArgoverseRawSequence
@@ -81,22 +87,25 @@ def _prep_flow(self, flow_dir: Path):
         self.timestamp_list = self.timestamp_list[: len(self.flow_data_files) + 1]
 
     @staticmethod
-    def get_class_str(class_id: int) -> Optional[str]:
-        if class_id not in CATEGORY_MAP:
+    def get_class_str(class_id: SemanticClassId) -> Optional[str]:
+        class_id_int = int(class_id)
+        if class_id_int not in CATEGORY_MAP:
             return None
-        return CATEGORY_MAP[class_id]
+        return CATEGORY_MAP[class_id_int]
 
-    def _make_default_classes(self, pc: PointCloud) -> np.ndarray:
-        return np.ones(len(pc.points), dtype=np.int32) * CATEGORY_MAP_INV["BACKGROUND"]
+    def _make_default_classes(self, pc: PointCloud) -> SemanticClassIdArray:
+        return np.ones(len(pc.points), dtype=SemanticClassId) * CATEGORY_MAP_INV["BACKGROUND"]
 
-    def _load_flow(
-        self, idx, classes_0: np.ndarray
-    ) -> tuple[Optional[np.ndarray], Optional[np.ndarray], np.ndarray]:
+    def _load_flow_feather(
+        self, idx: int, classes_0: SemanticClassIdArray
+    ) -> tuple[VectorArray, MaskArray, SemanticClassIdArray]:
         assert idx < len(self), f"idx {idx} out of range, len {len(self)} for {self.dataset_dir}"
         # There is no flow information for the last pointcloud in the sequence.
 
-        if idx == len(self) - 1 or idx == -1:
-            return None, None, classes_0
+        assert (
+            idx != len(self) - 1
+        ), f"idx {idx} is the last frame in the sequence, which has no flow data"
+        assert idx >= 0, f"idx {idx} is out of range"
         flow_data_file = self.flow_data_files[idx]
         flow_data = load_feather(flow_data_file, verbose=False)
         is_valid_arr = flow_data["is_valid"].values
@@ -113,28 +122,45 @@ def _load_flow(
 
         return flow_0_1, is_valid_arr, classes_0
 
-    def _load_no_flow(self, raw_item: RawItem) -> SceneFlowItem:
-        classes_0 = self._make_default_classes(raw_item.pc.pc)
-
-        return SceneFlowItem(
-            **vars(raw_item), pc_classes=classes_0, flowed_pc=copy.deepcopy(raw_item.pc)
+    def _make_tssf_item(
+        self, raw_item: TimeSyncedRawItem, classes_0: SemanticClassIdArray, flow: EgoLidarFlow
+    ) -> TimeSyncedSceneFlowItem:
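+        # "tssf" = TimeSyncedSceneFlow: promote the raw point cloud frame to a
+        # supervised frame with per-point class IDs, and attach the ego flow.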
+        supervised_pc = SupervisedPointCloudFrame(
+            **vars(raw_item.pc),
+            full_pc_classes=classes_0,
         )
+        return TimeSyncedSceneFlowItem(
+            pc=supervised_pc,
+            rgbs=raw_item.rgbs,
+            log_id=raw_item.log_id,
+            log_idx=raw_item.log_idx,
+            log_timestamp=raw_item.log_timestamp,
+            flow=flow,
+        )
+
+    def _load_no_flow(
+        self, raw_item: TimeSyncedRawItem, metadata: TimeSyncedAVLidarData
+    ) -> tuple[TimeSyncedSceneFlowItem, TimeSyncedAVLidarData]:
+        classes_0 = self._make_default_classes(raw_item.pc.pc)
+        flow = EgoLidarFlow.make_no_flow(len(classes_0))
+        return self._make_tssf_item(raw_item, classes_0, flow), metadata
 
-    def _load_with_flow(self, raw_item: RawItem, idx: int, relative_to_idx: int) -> SceneFlowItem:
+    def _load_with_flow(
+        self,
+        raw_item: TimeSyncedRawItem,
+        metadata: TimeSyncedAVLidarData,
+        idx: int,
+        relative_to_idx: int,
+    ) -> tuple[TimeSyncedSceneFlowItem, TimeSyncedAVLidarData]:
         start_pose = self._load_pose(relative_to_idx)
         idx_pose = self._load_pose(idx)
         relative_pose = start_pose.inverse().compose(idx_pose)
 
-        classes_0_with_ground = self._make_default_classes(raw_item.pc.pc)
         (
             relative_global_frame_flow_0_1_with_ground,
             is_valid_flow_with_ground_arr,
             classes_0_with_ground,
-        ) = self._load_flow(idx, classes_0_with_ground)
-
-        assert (
-            relative_global_frame_flow_0_1_with_ground is not None
-        ), f"Flow data missing for {idx}"
+        ) = self._load_flow_feather(idx, self._make_default_classes(raw_item.pc.pc))
 
         relative_global_frame_with_ground_flowed_pc = raw_item.pc.global_pc.copy()
         relative_global_frame_with_ground_flowed_pc.points[
@@ -145,24 +171,23 @@ def _load_with_flow(self, raw_item: RawItem, idx: int, relative_to_idx: int) ->
             relative_pose.inverse()
         )
 
-        return SceneFlowItem(
-            **vars(raw_item),
-            pc_classes=classes_0_with_ground,
-            flowed_pc=PointCloudFrame(
-                full_pc=ego_flowed_pc_with_ground, pose=raw_item.pc.pose, mask=raw_item.pc.mask
-            ),
-        )
+        delta_flow = ego_flowed_pc_with_ground.points - raw_item.pc.full_global_pc.points
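+        # Flow is stored as the per-point difference between the flowed and
+        # unflowed point positions.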
+
+        flow = EgoLidarFlow(full_flow=delta_flow, mask=is_valid_flow_with_ground_arr)
+        return (self._make_tssf_item(raw_item, classes_0_with_ground, flow), metadata)
 
-    def load(self, idx: int, relative_to_idx: int, with_flow: bool = True) -> SceneFlowItem:
+    def load(
+        self, idx: int, relative_to_idx: int, with_flow: bool = True
+    ) -> tuple[TimeSyncedSceneFlowItem, TimeSyncedAVLidarData]:
         assert idx < len(self), f"idx {idx} out of range, len {len(self)} for {self.dataset_dir}"
-        raw_item = super().load(idx, relative_to_idx)
+        raw_item, metadata = super().load(idx, relative_to_idx)
 
         if with_flow:
-            return self._load_with_flow(raw_item, idx, relative_to_idx)
+            return self._load_with_flow(raw_item, metadata, idx, relative_to_idx)
         else:
-            return self._load_no_flow(raw_item)
+            return self._load_no_flow(raw_item, metadata)
 
-    def load_frame_list(self, relative_to_idx: Optional[int]) -> list[RawItem]:
+    def load_frame_list(
+        self, relative_to_idx: Optional[int]
+    ) -> list[tuple[TimeSyncedSceneFlowItem, TimeSyncedAVLidarData]]:
         return [
             self.load(
                 idx=idx,
@@ -323,12 +348,14 @@ class ArgoverseNoFlowSequence(ArgoverseSceneFlowSequence):
     def _prep_flow(self, flow_dir: Path):
         pass
 
-    def _load_flow(
-        self, idx, classes_0: np.ndarray
-    ) -> tuple[Optional[np.ndarray], Optional[np.ndarray], np.ndarray]:
+    def _load_flow_feather(
+        self, idx: int, classes_0: SemanticClassIdArray
+    ) -> tuple[VectorArray, MaskArray, SemanticClassIdArray]:
         raise NotImplementedError("No flow data available for ArgoverseNoFlowSequence")
 
-    def load(self, idx: int, relative_to_idx: int, with_flow: bool = True) -> SceneFlowItem:
+    def load(
+        self, idx: int, relative_to_idx: int, with_flow: bool = True
+    ) -> tuple[TimeSyncedSceneFlowItem, TimeSyncedAVLidarData]:
         return super().load(idx, relative_to_idx, with_flow=False)
 
 
diff --git a/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py b/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py
index 9de5fa2..82c78be 100644
--- a/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py
+++ b/bucketed_scene_flow_eval/datasets/argoverse2/dataset.py
@@ -5,16 +5,13 @@
 
 import numpy as np
 
-from bucketed_scene_flow_eval.datasets.shared_datastructures import (
-    RawItem,
-    SceneFlowItem,
-)
 from bucketed_scene_flow_eval.datastructures import *
 from bucketed_scene_flow_eval.eval import (
     BucketedEPEEvaluator,
     Evaluator,
     ThreeWayEPEEvaluator,
 )
+from bucketed_scene_flow_eval.interfaces import AbstractDataset
 from bucketed_scene_flow_eval.utils import load_pickle, save_pickle
 
 from .argoverse_scene_flow import (
@@ -31,7 +28,7 @@ class EvalType(enum.Enum):
     THREEWAY_EPE = 1
 
 
-class Argoverse2SceneFlow:
+class Argoverse2SceneFlow(AbstractDataset):
     """
     Wrapper for the Argoverse 2 dataset.
 
@@ -131,126 +128,21 @@ def _av2_sequence_id_and_timestamp_to_idx(self, av2_sequence_id: str, timestamp:
         sequence_idx = sequence._timestamp_to_idx(timestamp)
         return self.sequence_subsequence_idx_to_dataset_idx[(sequence_loader_idx, sequence_idx)]
 
-    def _make_scene_sequence(
-        self, subsequence_frames: list[SceneFlowItem], log_id: str
-    ) -> RawSceneSequence:
-        # Build percept lookup. This stores the percepts for the entire sequence, with the
-        # global frame being zero'd at the target frame.
-        percept_lookup: dict[Timestamp, RawSceneItem] = {}
-        for dataset_idx, entry in enumerate(subsequence_frames):
-            if not self.with_ground:
-                entry.pc.mask = ~entry.is_ground_points
-                entry.flowed_pc.mask = ~entry.is_ground_points
-
-            rgb_frames = RGBFrameLookup.empty()
-            if self.with_rgb:
-                rgb_frames = entry.rgbs
-
-            percept_lookup[dataset_idx] = RawSceneItem(pc_frame=entry.pc, rgb_frames=rgb_frames)
-
-        return RawSceneSequence(percept_lookup, log_id)
-
-    def _make_dummy_query_scene_sequence(
-        self,
-        scene_sequence: RawSceneSequence,
-        subsequence_frames: Sequence[RawItem],
-        subsequence_src_index: int,
-        subsequence_tgt_index: int,
-    ) -> QuerySceneSequence:
-        query_timestamps: list[Timestamp] = [
-            subsequence_src_index,
-            subsequence_tgt_index,
-        ]
-        source_entry = subsequence_frames[subsequence_src_index]
-
-        query_particles = QueryPointLookup(len(source_entry.pc.full_pc), subsequence_src_index)
-
-        return QuerySceneSequence(scene_sequence, query_particles, query_timestamps)
-
-    def _make_query_scene_sequence(
-        self,
-        scene_sequence: RawSceneSequence,
-        subsequence_frames: Sequence[SceneFlowItem],
-        subsequence_src_index: int,
-        subsequence_tgt_index: int,
-    ) -> QuerySceneSequence:
-        query_scene_sequence = self._make_dummy_query_scene_sequence(
-            scene_sequence, subsequence_frames, subsequence_src_index, subsequence_tgt_index
-        )
-
-        source_entry = subsequence_frames[subsequence_src_index]
-        pc_points_array = source_entry.pc.full_global_pc.points
-        is_valid_points_array = source_entry.in_range_mask & source_entry.pc.mask
-
-        # Check that the in_range_points_array is the same size as the first dimension of pc_points_array
-        assert len(is_valid_points_array) == len(
-            pc_points_array
-        ), f"Is valid points and pc points have different lengths. Is valid: {len(is_valid_points_array)}, pc points: {len(pc_points_array)}"
-
-        particle_ids = np.arange(len(is_valid_points_array))
-        query_scene_sequence.query_particles[particle_ids[is_valid_points_array]] = pc_points_array[
-            is_valid_points_array
-        ]
-        return query_scene_sequence
-
-    def _make_results_scene_sequence(
-        self,
-        query: QuerySceneSequence,
-        subsequence_frames: list[SceneFlowItem],
-        subsequence_src_index: int,
-        subsequence_tgt_index: int,
-    ) -> GroundTruthPointFlow:
-        # Build query scene sequence. This requires enumerating all points in
-        # the source frame and the associated flowed points.
+    def _process_with_metadata(
+        self, item: TimeSyncedSceneFlowItem, metadata: TimeSyncedAVLidarData
+    ) -> TimeSyncedSceneFlowItem:
+        # Falsify the PC mask for out-of-range points.
+        item.pc.mask = item.pc.mask & metadata.in_range_mask
+        # Falsify the flow mask for out-of-range points.
+        item.flow.mask = item.flow.mask & metadata.in_range_mask
 
-        source_entry = subsequence_frames[subsequence_src_index]
+        if not self.with_ground:
+            item.pc = item.pc.mask_points(~metadata.is_ground_points)
+            item.flow = item.flow.mask_points(~metadata.is_ground_points)
 
-        assert (
-            source_entry.pc.mask == source_entry.flowed_pc.mask
-        ).all(), f"Mask and flowed mask are different."
+        return item
 
-        source_pc = source_entry.pc.full_global_pc.points
-        target_pc = source_entry.flowed_pc.full_global_pc.points
-
-        is_valid_points_array = source_entry.in_range_mask & source_entry.pc.mask
-        pc_class_ids = source_entry.pc_classes
-        assert len(source_pc) == len(
-            target_pc
-        ), "Source and target point clouds must be the same size."
-        assert len(source_pc) == len(
-            pc_class_ids
-        ), f"Source point cloud and class ids must be the same size. Instead got {len(source_pc)} and {len(pc_class_ids)}."
-
-        particle_trajectories = GroundTruthPointFlow(
-            len(source_pc),
-            np.array([subsequence_src_index, subsequence_tgt_index]),
-            query.query_particles.query_init_timestamp,
-            CATEGORY_MAP,
-        )
-
-        points = np.stack([source_pc, target_pc], axis=1)
-
-        particle_ids = np.arange(len(source_pc))
-
-        # is_valids needs to respect the points mask described in the query scene sequence pointcloud.
-        first_timestamp = query.scene_sequence.get_percept_timesteps()[0]
-        is_valids = query.scene_sequence[first_timestamp].pc_frame.mask
-
-        assert len(is_valids) == len(
-            points
-        ), f"Is valids and points have different lengths. Is valids: {len(is_valids)}, points: {len(points)}"
-
-        particle_trajectories[particle_ids[is_valid_points_array]] = (
-            points[is_valid_points_array],
-            pc_class_ids[is_valid_points_array],
-            is_valids[is_valid_points_array],
-        )
-
-        return particle_trajectories
-
-    def __getitem__(
-        self, dataset_idx, verbose: bool = False
-    ) -> tuple[QuerySceneSequence, GroundTruthPointFlow]:
+    def __getitem__(self, dataset_idx, verbose: bool = False) -> list[TimeSyncedSceneFlowItem]:
         if verbose:
             print(f"Argoverse2 Scene Flow dataset __getitem__({dataset_idx}) start")
 
@@ -272,26 +164,15 @@ def __getitem__(
             for i in range(self.subsequence_length)
         ]
 
-        scene_sequence = self._make_scene_sequence(subsequence_frames, sequence.log_id)
-
-        query_scene_sequence = self._make_query_scene_sequence(
-            scene_sequence,
-            subsequence_frames,
-            in_subsequence_src_index,
-            in_subsequence_tgt_index,
-        )
-
-        results_scene_sequence = self._make_results_scene_sequence(
-            query_scene_sequence,
-            subsequence_frames,
-            in_subsequence_src_index,
-            in_subsequence_tgt_index,
-        )
+        scene_flow_items = [
+            self._process_with_metadata(item, metadata)
+            for item, metadata in subsequence_frames
+        ]
 
-        if verbose:
-            print(f"Argoverse2 Scene Flow dataset __getitem__({dataset_idx}) end")
 
-        return query_scene_sequence, results_scene_sequence
+        return scene_flow_items
 
     def evaluator(self) -> Evaluator:
         eval_args_copy = copy.deepcopy(self.eval_args)
@@ -299,10 +180,14 @@ def evaluator(self) -> Evaluator:
         if self.eval_type == EvalType.BUCKETED_EPE:
             if "meta_class_lookup" not in eval_args_copy:
                 eval_args_copy["meta_class_lookup"] = BUCKETED_METACATAGORIES
+            if "class_id_to_name" not in eval_args_copy:
+                eval_args_copy["class_id_to_name"] = CATEGORY_MAP
             return BucketedEPEEvaluator(**eval_args_copy)
         elif self.eval_type == EvalType.THREEWAY_EPE:
             if "meta_class_lookup" not in eval_args_copy:
                 eval_args_copy["meta_class_lookup"] = THREEWAY_EPE_METACATAGORIES
+            if "class_id_to_name" not in eval_args_copy:
+                eval_args_copy["class_id_to_name"] = CATEGORY_MAP
             return ThreeWayEPEEvaluator(**eval_args_copy)
         else:
             raise ValueError(f"Unknown eval type {self.eval_type}")
diff --git a/bucketed_scene_flow_eval/datasets/nuscenes/nuscenes_loader.py b/bucketed_scene_flow_eval/datasets/nuscenes/nuscenes_loader.py
index aa48f22..64cb54c 100644
--- a/bucketed_scene_flow_eval/datasets/nuscenes/nuscenes_loader.py
+++ b/bucketed_scene_flow_eval/datasets/nuscenes/nuscenes_loader.py
@@ -9,7 +9,7 @@
 from PIL import Image
 from pyquaternion import Quaternion
 
-from bucketed_scene_flow_eval.datasets.shared_datastructures import (
+from bucketed_scene_flow_eval.interfaces import (
     AbstractSequence,
     AbstractSequenceLoader,
     CachedSequenceLoader,
diff --git a/bucketed_scene_flow_eval/datasets/shared_datastructures/scene_representations.py b/bucketed_scene_flow_eval/datasets/shared_datastructures/scene_representations.py
deleted file mode 100644
index f67b2c4..0000000
--- a/bucketed_scene_flow_eval/datasets/shared_datastructures/scene_representations.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional
-
-import numpy as np
-
-from bucketed_scene_flow_eval.datastructures import (
-    SE3,
-    CameraProjection,
-    PointCloud,
-    PointCloudFrame,
-    RGBFrame,
-    RGBFrameLookup,
-    RGBImage,
-    Timestamp,
-)
-
-
-@dataclass(kw_only=True)
-class RawItem:
-    pc: PointCloudFrame
-    is_ground_points: np.ndarray
-    in_range_mask: np.ndarray
-    rgbs: RGBFrameLookup
-    log_id: str
-    log_idx: int
-    log_timestamp: Timestamp
-
-
-@dataclass(kw_only=True)
-class SceneFlowItem(RawItem):
-    pc_classes: np.ndarray
-    flowed_pc: PointCloudFrame
diff --git a/bucketed_scene_flow_eval/datasets/waymoopen/dataset.py b/bucketed_scene_flow_eval/datasets/waymoopen/dataset.py
index 618a8a5..4203306 100644
--- a/bucketed_scene_flow_eval/datasets/waymoopen/dataset.py
+++ b/bucketed_scene_flow_eval/datasets/waymoopen/dataset.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from bucketed_scene_flow_eval.datasets.shared_datastructures import SceneFlowItem
+from bucketed_scene_flow_eval.datastructures import TimeSyncedSceneFlowItem
 from bucketed_scene_flow_eval.datastructures import *
 from bucketed_scene_flow_eval.eval import BucketedEPEEvaluator, Evaluator
 from bucketed_scene_flow_eval.utils import load_pickle, save_pickle
@@ -69,20 +69,22 @@ def __len__(self):
         return len(self.dataset_to_sequence_subsequence_idx)
 
     def _make_scene_sequence(
-        self, subsequence_frames: list[SceneFlowItem], seq_id: str
+        self, subsequence_frames: list[TimeSyncedSceneFlowItem], seq_id: str
     ) -> RawSceneSequence:
         # Build percept lookup. This stores the percepts for the entire sequence, with the
         # global frame being zero'd at the target frame.
-        percept_lookup: dict[Timestamp, RawSceneItem] = {}
+        percept_lookup: dict[Timestamp, TimeSyncedRawSceneFrame] = {}
         for dataset_idx, entry in enumerate(subsequence_frames):
-            percept_lookup[dataset_idx] = RawSceneItem(pc_frame=entry.pc, rgb_frames=entry.rgbs)
+            percept_lookup[dataset_idx] = TimeSyncedRawSceneFrame(
+                pc_frame=entry.pc, rgb_frames=entry.rgbs
+            )
 
         return RawSceneSequence(percept_lookup, seq_id)
 
     def _make_query_scene_sequence(
         self,
         scene_sequence: RawSceneSequence,
-        subsequence_frames: list[SceneFlowItem],
+        subsequence_frames: list[TimeSyncedSceneFlowItem],
         subsequence_src_index: int,
         subsequence_tgt_index: int,
     ) -> QuerySceneSequence:
@@ -101,7 +103,7 @@ def _make_query_scene_sequence(
     def _make_results_scene_sequence(
         self,
         query: QuerySceneSequence,
-        subsequence_frames: list[SceneFlowItem],
+        subsequence_frames: list[TimeSyncedSceneFlowItem],
         subsequence_src_index: int,
         subsequence_tgt_index: int,
     ) -> GroundTruthPointFlow:
@@ -110,8 +112,8 @@ def _make_results_scene_sequence(
 
         source_entry = subsequence_frames[subsequence_src_index]
         source_pc = source_entry.pc.global_pc.points
-        target_pc = source_entry.flowed_pc.global_pc.points
-        pc_class_ids = source_entry.pc_classes
+        target_pc = source_entry.flow.global_pc.points
+        pc_class_ids = source_entry.full_pc_classes
         assert len(source_pc) == len(
             target_pc
         ), "Source and target point clouds must be the same size."
diff --git a/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py b/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py
index a173675..55c32ac 100644
--- a/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py
+++ b/bucketed_scene_flow_eval/datasets/waymoopen/waymo_supervised_flow.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from bucketed_scene_flow_eval.datasets.shared_datastructures import (
+from bucketed_scene_flow_eval.interfaces import (
     AbstractSequence,
     AbstractSequenceLoader,
     CachedSequenceLoader,
diff --git a/bucketed_scene_flow_eval/datastructures/__init__.py b/bucketed_scene_flow_eval/datastructures/__init__.py
index 287378a..7a27e96 100644
--- a/bucketed_scene_flow_eval/datastructures/__init__.py
+++ b/bucketed_scene_flow_eval/datastructures/__init__.py
@@ -1,49 +1,48 @@
 from .camera_projection import CameraModel, CameraProjection
-from .o3d_visualizer import O3DVisualizer
-from .pointcloud import PointCloud, from_fixed_array, to_fixed_array
-from .rgb_image import RGBImage
-from .scene_sequence import (
-    EstimatedPointFlow,
-    GroundTruthPointFlow,
-    ParticleClassId,
-    ParticleID,
+from .dataclasses import (
+    EgoLidarFlow,
+    MaskArray,
     PointCloudFrame,
     PoseInfo,
-    QueryPointLookup,
-    QuerySceneSequence,
-    RawSceneItem,
-    RawSceneSequence,
     RGBFrame,
     RGBFrameLookup,
-    Timestamp,
-    WorldParticle,
+    SemanticClassId,
+    SemanticClassIdArray,
+    SupervisedPointCloudFrame,
+    TimeSyncedAVLidarData,
+    TimeSyncedBaseAuxilaryData,
+    TimeSyncedRawItem,
+    TimeSyncedSceneFlowItem,
+    VectorArray,
 )
+from .o3d_visualizer import O3DVisualizer
+from .pointcloud import PointCloud, from_fixed_array, to_fixed_array
+from .rgb_image import RGBImage
 from .se2 import SE2
 from .se3 import SE3
 
 __all__ = [
-    "PointCloud",
-    "to_fixed_array",
-    "from_fixed_array",
-    "SE3",
-    "SE2",
-    "RawSceneItem",
-    "RGBImage",
-    "CameraProjection",
     "CameraModel",
-    "RawSceneSequence",
+    "CameraProjection",
+    "EgoLidarFlow",
+    "MaskArray",
     "PointCloudFrame",
+    "PoseInfo",
     "RGBFrame",
     "RGBFrameLookup",
-    "PoseInfo",
-    "QuerySceneSequence",
+    "SemanticClassId",
+    "SemanticClassIdArray",
+    "SupervisedPointCloudFrame",
+    "TimeSyncedAVLidarData",
+    "TimeSyncedBaseAuxilaryData",
+    "TimeSyncedRawItem",
+    "TimeSyncedSceneFlowItem",
+    "VectorArray",
     "O3DVisualizer",
-    "ParticleID",
-    "ParticleClassId",
-    "Timestamp",
-    "Timestamp",
-    "WorldParticle",
-    "QueryPointLookup",
-    "GroundTruthPointFlow",
-    "EstimatedPointFlow",
+    "PointCloud",
+    "from_fixed_array",
+    "to_fixed_array",
+    "RGBImage",
+    "SE2",
+    "SE3",
 ]
diff --git a/bucketed_scene_flow_eval/datastructures/camera_projection.py b/bucketed_scene_flow_eval/datastructures/camera_projection.py
index bba6403..b7793b9 100644
--- a/bucketed_scene_flow_eval/datastructures/camera_projection.py
+++ b/bucketed_scene_flow_eval/datastructures/camera_projection.py
@@ -30,6 +30,15 @@ def __init__(self, fx: float, fy: float, cx: float, cy: float, camera_model: Cam
     def __repr__(self) -> str:
         return f"CameraProjection(fx={self.fx}, fy={self.fy}, cx={self.cx}, cy={self.cy}, camera_model={self.camera_model})"
 
+    def rescale(self, reduction_factor: int) -> "CameraProjection":
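+        # Scale the intrinsics to match an image downsampled by
+        # `reduction_factor`, e.g. rescale(2) halves fx, fy, cx and cy.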
+        return CameraProjection(
+            fx=self.fx / reduction_factor,
+            fy=self.fy / reduction_factor,
+            cx=self.cx / reduction_factor,
+            cy=self.cy / reduction_factor,
+            camera_model=self.camera_model,
+        )
+
     def image_to_image_plane_pc(
         self, image: RGBImage, depth: float = 1.0
     ) -> tuple[PointCloud, np.ndarray]:
diff --git a/bucketed_scene_flow_eval/datastructures/dataclasses.py b/bucketed_scene_flow_eval/datastructures/dataclasses.py
new file mode 100644
index 0000000..f6d8476
--- /dev/null
+++ b/bucketed_scene_flow_eval/datastructures/dataclasses.py
@@ -0,0 +1,231 @@
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+from .camera_projection import CameraProjection
+from .pointcloud import PointCloud
+from .rgb_image import RGBImage
+from .se3 import SE3
+
+SemanticClassId = np.int8
+SemanticClassIdArray = NDArray[SemanticClassId]
+MaskArray = NDArray[np.bool_]
+VectorArray = NDArray[np.float32]
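+# Array aliases used by the dataclasses below: a 1D boolean validity mask
+# (MaskArray) and an (N, 3) float32 vector array (VectorArray).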
+
+
+@dataclass
+class PoseInfo:
+    sensor_to_ego: SE3
+    ego_to_global: SE3
+
+    def __eq__(self, __value: object) -> bool:
+        if not isinstance(__value, PoseInfo):
+            return False
+        return (
+            self.sensor_to_ego == __value.sensor_to_ego
+            and self.ego_to_global == __value.ego_to_global
+        )
+
+    def __repr__(self) -> str:
+        return f"PoseInfo(sensor_to_ego={self.sensor_to_ego}, ego_to_global={self.ego_to_global})"
+
+
+@dataclass
+class PointCloudFrame:
+    full_pc: PointCloud
+    pose: PoseInfo
+    mask: MaskArray
+
+    @property
+    def pc(self) -> PointCloud:
+        return self.full_pc.mask_points(self.mask)
+
+    @property
+    def full_global_pc(self) -> PointCloud:
+        pose = self.global_pose
+        return self.full_pc.transform(pose)
+
+    @property
+    def global_pc(self) -> PointCloud:
+        pose = self.global_pose
+        return self.pc.transform(pose)
+
+    @property
+    def global_pose(self) -> SE3:
+        return self.pose.ego_to_global @ self.pose.sensor_to_ego
+
+    def mask_points(self, mask: MaskArray) -> "PointCloudFrame":
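+        # `mask` indexes the full point cloud; both the points and the validity
+        # mask are sliced so the (full_pc, mask) pair stays consistent. Sketch:
+        #   near = frame.mask_points(np.linalg.norm(frame.full_pc.points, axis=1) < 35.0)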
+        assert isinstance(mask, np.ndarray), f"mask must be an ndarray, got {type(mask)}"
+        assert mask.ndim == 1, f"mask must be a 1D array, got {mask.ndim}"
+        assert mask.dtype == bool, f"mask must be a boolean array, got {mask.dtype}"
+        return PointCloudFrame(
+            full_pc=self.full_pc.mask_points(mask),
+            pose=self.pose,
+            mask=self.mask[mask],
+        )
+
+
+@dataclass
+class SupervisedPointCloudFrame(PointCloudFrame):
+    full_pc_classes: SemanticClassIdArray
+
+    def __post_init__(self):
+        # Validate full_pc_classes.
+        assert isinstance(
+            self.full_pc_classes, np.ndarray
+        ), f"full_pc_classes must be an ndarray, got {type(self.full_pc_classes)}"
+        assert (
+            self.full_pc_classes.ndim == 1
+        ), f"full_pc_classes must be a 1D array, got {self.full_pc_classes.ndim}"
+        assert (
+            self.full_pc_classes.dtype == SemanticClassId
+        ), f"full_pc_classes must be a SemanticClassId array, got {self.full_pc_classes.dtype}"
+        assert len(self.full_pc_classes) == len(
+            self.full_pc
+        ), f"full_pc_classes must be the same length as full_pc, got {len(self.full_pc_classes)} and {len(self.full_pc)}"
+
+    @property
+    def pc_classes(self) -> SemanticClassIdArray:
+        return self.full_pc_classes[self.mask]
+
+    def mask_points(self, mask: MaskArray) -> "SupervisedPointCloudFrame":
+        assert isinstance(mask, np.ndarray), f"mask must be an ndarray, got {type(mask)}"
+        assert mask.ndim == 1, f"mask must be a 1D array, got {mask.ndim}"
+        assert mask.dtype == bool, f"mask must be a boolean array, got {mask.dtype}"
+        return SupervisedPointCloudFrame(
+            full_pc=self.full_pc.mask_points(mask),
+            pose=self.pose,
+            mask=self.mask[mask],
+            full_pc_classes=self.full_pc_classes[mask],
+        )
+
+
+@dataclass
+class RGBFrame:
+    rgb: RGBImage
+    pose: PoseInfo
+    camera_projection: CameraProjection
+
+    def __repr__(self) -> str:
+        return f"RGBFrame(rgb={self.rgb},\npose={self.pose},\ncamera_projection={self.camera_projection})"
+
+    def rescale(self, factor: int) -> "RGBFrame":
+        return RGBFrame(
+            rgb=self.rgb.rescale(factor),
+            pose=self.pose,
+            camera_projection=self.camera_projection.rescale(factor),
+        )
+
+
+@dataclass
+class RGBFrameLookup:
+    lookup: dict[str, RGBFrame]
+    entries: list[str]
+
+    @staticmethod
+    def empty() -> "RGBFrameLookup":
+        return RGBFrameLookup({}, [])
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.lookup
+
+    def items(self) -> list[tuple[str, RGBFrame]]:
+        return [(key, self.lookup[key]) for key in self.entries]
+
+    def values(self) -> list[RGBFrame]:
+        return [self.lookup[key] for key in self.entries]
+
+    def __getitem__(self, key: str) -> RGBFrame:
+        return self.lookup[key]
+
+    def __len__(self) -> int:
+        return len(self.lookup)
+
+
+@dataclass
+class EgoLidarFlow:
+    """
+    Lidar flow from P0 to P1, expressed as per-point flow vectors in the ego frame of P0.
+    """
+
+    full_flow: VectorArray
+    mask: MaskArray
+
+    @staticmethod
+    def make_no_flow(flow_dim: int) -> "EgoLidarFlow":
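+        # Zero flow vectors with an all-False validity mask: a placeholder for
+        # frames that have no flow labels (e.g. the last frame of a sequence).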
+        return EgoLidarFlow(
+            full_flow=np.zeros((flow_dim, 3), dtype=np.float32),
+            mask=np.zeros(flow_dim, dtype=bool),
+        )
+
+    def __post_init__(self):
+        assert self.full_flow.ndim == 2, f"full_flow must be a 2D array, got {self.full_flow.ndim}"
+        assert self.mask.ndim == 1, f"mask must be a 1D array, got {self.mask.ndim}"
+        assert self.mask.dtype == bool, f"mask must be boolean, got {self.mask.dtype}"
+
+        assert len(self.full_flow) == len(self.mask), (
+            f"full_flow and mask must have the same length, got {len(self.full_flow)} and "
+            f"{len(self.mask)}"
+        )
+
+        assert (
+            self.full_flow.shape[1] == 3
+        ), f"full_flow must have 3 columns, got {self.full_flow.shape[1]}"
+
+    def __repr__(self) -> str:
+        return f"LidarFlow(flow={self.full_flow}, valid_flow_mask={self.mask})"
+
+    @property
+    def valid_flow(self) -> VectorArray:
+        return self.full_flow[self.mask]
+
+    @property
+    def shape(self) -> tuple[int, int]:
+        return self.full_flow.shape
+
+    def mask_points(self, mask: MaskArray) -> "EgoLidarFlow":
+        assert isinstance(mask, np.ndarray), f"mask must be an ndarray, got {type(mask)}"
+        assert mask.ndim == 1, f"mask must be a 1D array, got {mask.ndim}"
+        assert mask.dtype == bool, f"mask must be a boolean array, got {mask.dtype}"
+        return EgoLidarFlow(full_flow=self.full_flow[mask], mask=self.mask[mask])
+
+
+@dataclass(kw_only=True)
+class TimeSyncedBaseAuxilaryData:
+    pass
+
+
+@dataclass(kw_only=True)
+class TimeSyncedAVLidarData(TimeSyncedBaseAuxilaryData):
+    is_ground_points: MaskArray
+    in_range_mask: MaskArray
+
+
+@dataclass(kw_only=True)
+class TimeSyncedRawItem:
+    pc: PointCloudFrame
+    rgbs: RGBFrameLookup
+    log_id: str
+    log_idx: int
+    log_timestamp: int
+
+
+@dataclass(kw_only=True)
+class TimeSyncedSceneFlowItem(TimeSyncedRawItem):
+    pc: SupervisedPointCloudFrame
+    flow: EgoLidarFlow
+
+    def __post_init__(self):
+        assert isinstance(
+            self.pc, SupervisedPointCloudFrame
+        ), f"pc must be a SupervisedPointCloudFrame, got {type(self.pc)}"
+        assert isinstance(
+            self.flow, EgoLidarFlow
+        ), f"flow must be an EgoLidarFlow, got {type(self.flow)}"
+
+        # Ensure full flow is the same shape as full pc
+        assert len(self.flow.full_flow) == len(
+            self.pc.full_pc
+        ), f"flow and pc must have the same length, got {len(self.flow.full_flow)} and {len(self.pc.full_pc)}"
diff --git a/bucketed_scene_flow_eval/datastructures/o3d_visualizer.py b/bucketed_scene_flow_eval/datastructures/o3d_visualizer.py
index 96d242e..f61c926 100644
--- a/bucketed_scene_flow_eval/datastructures/o3d_visualizer.py
+++ b/bucketed_scene_flow_eval/datastructures/o3d_visualizer.py
@@ -3,6 +3,7 @@
 import numpy as np
 import open3d as o3d
 
+from .dataclasses import PointCloudFrame, RGBFrame
 from .pointcloud import PointCloud
 from .se3 import SE3
 
@@ -24,13 +25,22 @@ def add_geometry(self, geometry):
         else:
             self.geometry_list.append(geometry)
 
-    def add_pc_frame(
+    def add_global_pc_frame(
         self,
-        pc_frame: "PointCloudFrame",
+        pc_frame: PointCloudFrame,
         color: Union[tuple[float, float, float], None] = None,
     ):
         self.add_pointcloud(pc_frame.global_pc, color=color)
 
+    def add_global_rgb_frame(self, rgb_frame: RGBFrame):
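+        # Back-project the RGB image onto a plane at a fixed depth in front of
+        # the camera, then move it into the global frame so it renders
+        # alongside the lidar points.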
+        image_plane_pc, colors = rgb_frame.camera_projection.image_to_image_plane_pc(
+            rgb_frame.rgb, depth=20
+        )
+        image_plane_pc = image_plane_pc.transform(
+            rgb_frame.pose.ego_to_global.compose(rgb_frame.pose.sensor_to_ego.inverse())
+        )
+        self.add_pointcloud(image_plane_pc, color=colors)
+
     def add_pointcloud(
         self,
         pc: PointCloud,
@@ -40,17 +50,17 @@ def add_pointcloud(
         ] = None,
     ):
         pc = pc.transform(pose)
-        pc = pc.to_o3d()
+        pc_o3d = pc.to_o3d()
         if color is not None:
             color = np.array(color)
             if color.ndim == 1:
-                pc = pc.paint_uniform_color(color)
+                pc_o3d = pc_o3d.paint_uniform_color(color)
             elif color.ndim == 2:
                 assert len(color) == len(
-                    pc.points
-                ), f"Expected color to have length {len(pc.points)}, got {len(color)} instead"
-                pc.colors = o3d.utility.Vector3dVector(color)
-        self.add_geometry(pc)
+                    pc_o3d.points
+                ), f"Expected color to have length {len(pc_o3d.points)}, got {len(color)} instead"
+                pc_o3d.colors = o3d.utility.Vector3dVector(color)
+        self.add_geometry(pc_o3d)
 
     def add_sphere(self, location: np.ndarray, radius: float, color: tuple[float, float, float]):
         sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius, resolution=2)
@@ -146,13 +156,3 @@ def run(self):
             vis.add_geometry(geometry)
 
         vis.run()
-
-        #
-        # o3d.visualization.draw_geometries(self.geometry_list)
-        # ctr = self.vis.get_view_control()
-        # # Set forward direction to be -X
-        # ctr.set_front([-1, 0, 0])
-        # # Set up direction to be +Z
-        # ctr.set_up([0, 0, 1])
-        # # Set lookat to be origin
-        # ctr.set_lookat([0, 0, 0])
diff --git a/bucketed_scene_flow_eval/datastructures/pointcloud.py b/bucketed_scene_flow_eval/datastructures/pointcloud.py
index c51a1e5..d6a6b23 100644
--- a/bucketed_scene_flow_eval/datastructures/pointcloud.py
+++ b/bucketed_scene_flow_eval/datastructures/pointcloud.py
@@ -186,7 +186,7 @@ def within_region(self, x_min, x_max, y_min, y_max, z_min, z_max) -> "PointCloud
         return self.mask_points(mask)
 
     @property
-    def shape(self) -> tuple:
+    def shape(self) -> tuple[int, int]:
         return self.points.shape
 
     def to_o3d(self) -> o3d.geometry.PointCloud:
diff --git a/bucketed_scene_flow_eval/datastructures/scene_sequence.py b/bucketed_scene_flow_eval/datastructures/scene_sequence.py
deleted file mode 100644
index 5dda8ae..0000000
--- a/bucketed_scene_flow_eval/datastructures/scene_sequence.py
+++ /dev/null
@@ -1,427 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional, Union
-
-import numpy as np
-from numpy._typing import NDArray
-
-from .camera_projection import CameraProjection
-from .o3d_visualizer import O3DVisualizer
-from .pointcloud import PointCloud
-from .rgb_image import RGBImage
-from .se3 import SE3
-
-# Type alias for particle IDs
-ParticleID = int
-ParticleClassId = int
-Timestamp = int
-
-# Type alias for world points
-WorldParticle = np.ndarray
-
-
-@dataclass
-class PoseInfo:
-    sensor_to_ego: SE3
-    ego_to_global: SE3
-
-    def __eq__(self, __value: object) -> bool:
-        if not isinstance(__value, PoseInfo):
-            return False
-        return (
-            self.sensor_to_ego == __value.sensor_to_ego
-            and self.ego_to_global == __value.ego_to_global
-        )
-
-    def __repr__(self) -> str:
-        return f"PoseInfo(sensor_to_ego={self.sensor_to_ego}, ego_to_global={self.ego_to_global})"
-
-
-@dataclass
-class PointCloudFrame:
-    full_pc: PointCloud
-    pose: PoseInfo
-    mask: NDArray
-
-    @property
-    def pc(self) -> PointCloud:
-        return self.full_pc.mask_points(self.mask)
-
-    @property
-    def full_global_pc(self) -> PointCloud:
-        pose = self.global_pose
-        return self.full_pc.transform(pose)
-
-    @property
-    def global_pc(self) -> PointCloud:
-        pose = self.global_pose
-        return self.pc.transform(pose)
-
-    @property
-    def global_pose(self) -> SE3:
-        return self.pose.ego_to_global @ self.pose.sensor_to_ego
-
-    def add_global_flow(self, flow: NDArray, valid_flow_mask: NDArray) -> "PointCloudFrame":
-        assert flow.ndim == 2, f"flow must be a 2D array, got {flow.ndim}"
-        assert (
-            valid_flow_mask.ndim == 1
-        ), f"valid_flow_mask must be a 1D array, got {valid_flow_mask.ndim}"
-        assert (
-            valid_flow_mask.dtype == bool
-        ), f"valid_flow_mask must be boolean, got {valid_flow_mask.dtype}"
-
-        assert len(flow) == len(valid_flow_mask), (
-            f"flow and valid_flow_mask must have the same length, got {len(flow)} and "
-            f"{len(valid_flow_mask)}"
-        )
-
-        assert len(flow) == len(
-            self.full_pc
-        ), f"flow shape {flow.shape} must match point cloud shape {len(self.full_pc)}"
-
-        assert self.mask.shape == valid_flow_mask.shape, (
-            f"mask and valid_flow_mask must have the same length, got {len(self.mask)} and "
-            f"{len(valid_flow_mask)}"
-        )
-
-        # Convert to global pc, add flow, then convert back to ego frame
-        flowed_ego_pc = self.full_global_pc.flow_masked(
-            flow[valid_flow_mask], valid_flow_mask
-        ).transform(self.global_pose.inverse())
-        # Only include points that are valid and in the mask
-        joined_mask = self.mask & valid_flow_mask
-
-        return PointCloudFrame(
-            full_pc=flowed_ego_pc,
-            pose=self.pose,
-            mask=joined_mask,
-        )
-
-
-@dataclass
-class RGBFrame:
-    rgb: RGBImage
-    pose: PoseInfo
-    camera_projection: CameraProjection
-
-    def __repr__(self) -> str:
-        return f"RGBFrame(rgb={self.rgb},\npose={self.pose},\ncamera_projection={self.camera_projection})"
-
-
-@dataclass
-class RGBFrameLookup:
-    lookup: dict[str, RGBFrame]
-    entries: list[str]
-
-    @staticmethod
-    def empty() -> "RGBFrameLookup":
-        return RGBFrameLookup({}, [])
-
-    def __contains__(self, key: str) -> bool:
-        return key in self.lookup
-
-    def items(self) -> list[tuple[str, RGBFrame]]:
-        return [(key, self.lookup[key]) for key in self.entries]
-
-    def values(self) -> list[RGBFrame]:
-        return [self.lookup[key] for key in self.entries]
-
-    def __getitem__(self, key: str) -> RGBFrame:
-        return self.lookup[key]
-
-    def __len__(self) -> int:
-        return len(self.lookup)
-
-
-@dataclass
-class RawSceneItem:
-    pc_frame: PointCloudFrame
-    rgb_frames: RGBFrameLookup
-
-
-def _particle_id_to_color(particle_id: ParticleID) -> NDArray:
-    particle_id = int(particle_id)
-    assert isinstance(
-        particle_id, ParticleID
-    ), f"particle_id must be a ParticleID ({ParticleID}) , got {type(particle_id)}"
-    hash_val = abs(hash(particle_id)) % (256**3)
-    return np.array(
-        [
-            ((hash_val >> 16) & 0xFF) / 255,
-            ((hash_val >> 8) & 0xFF) / 255,
-            (hash_val & 0xFF) / 255,
-        ]
-    )
-
-
-class RawSceneSequence:
-    """
-    This class contains only the raw percepts from a sequence. Its goal is to
-    describe the scene as it is observed by the sensors; it does not contain
-    any other information such as point position descriptions.
-
-    These percept modalities are:
-        - RGB
-        - PointClouds
-
-    Additionally, we store frame conversions for each percept.
-    """
-
-    def __init__(self, percept_lookup: dict[Timestamp, RawSceneItem], log_id: str):
-        assert isinstance(
-            percept_lookup, dict
-        ), f"percept_lookup must be a dict, got {type(percept_lookup)}"
-        assert all(
-            isinstance(key, Timestamp) for key in percept_lookup.keys()
-        ), f"percept_lookup keys must be Timestamp, got {[type(key) for key in percept_lookup.keys()]}"
-        assert all(
-            isinstance(value, RawSceneItem) for value in percept_lookup.values()
-        ), f"percept_lookup values must be RawSceneItem, got {[type(value) for value in percept_lookup.values()]}"
-        self.percept_lookup = percept_lookup
-        self.log_id = log_id
-
-    def get_percept_timesteps(self) -> list[int]:
-        return sorted(self.percept_lookup.keys())
-
-    def __len__(self):
-        return len(self.get_percept_timesteps())
-
-    def __getitem__(self, timestamp: int) -> RawSceneItem:
-        assert isinstance(timestamp, int), f"timestamp must be an int, got {type(timestamp)}"
-        return self.percept_lookup[timestamp]
-
-    def visualize(self, vis: O3DVisualizer) -> O3DVisualizer:
-        timesteps = self.get_percept_timesteps()
-        grayscale_color = np.linspace(0, 1, len(timesteps) + 1)
-        for idx, timestamp in enumerate(timesteps):
-            item: RawSceneItem = self[timestamp]
-
-            vis.add_pc_frame(item.pc_frame, color=[grayscale_color[idx]] * 3)
-            vis.add_pose(item.pc_frame.global_pose)
-        return vis
-
-    def __eq__(self, __value: object) -> bool:
-        if not isinstance(__value, RawSceneSequence):
-            return False
-        return self.percept_lookup == __value.percept_lookup and self.log_id == __value.log_id
-
-
-class QueryPointLookup:
-    """
-    This class is an efficient lookup table for query points.
-    """
-
-    def __init__(self, num_entries: int, query_init_timestamp: Timestamp):
-        self.num_entries = num_entries
-        self.query_init_world_points = np.zeros((num_entries, 3), dtype=np.float32)
-        self.query_init_timestamp = query_init_timestamp
-        self.is_valid = np.zeros((num_entries,), dtype=bool)
-
-    def __len__(self) -> int:
-        return self.is_valid.sum()
-
-    def __getitem__(self, particle_id: ParticleID) -> tuple[WorldParticle, Timestamp]:
-        assert (
-            particle_id < self.num_entries
-        ), f"particle_id {particle_id} must be less than {self.num_entries}"
-        return self.query_init_world_points[particle_id], self.query_init_timestamp
-
-    def __setitem__(self, particle_id_arr: np.ndarray, value: WorldParticle):
-        assert (
-            particle_id_arr < self.num_entries
-        ).all(), f"particle_id value must be less than {self.num_entries}"
-        self.query_init_world_points[particle_id_arr] = value
-        self.is_valid[particle_id_arr] = True
-
-    @property
-    def particle_ids(self) -> NDArray:
-        return np.arange(self.num_entries)[self.is_valid]
-
-    def valid_query_init_world_points(self) -> NDArray:
-        return self.query_init_world_points[self.is_valid]
-
-
-class QuerySceneSequence:
-    """
-    This class describes a scene sequence with a query for motion descriptions.
-
-    A query is a point + timestamp in the global frame of the scene, along with
-    series of timestamps for which a point description is requested; motion is
-    implied to be linear between these points at the requested timestamps.
-    """
-
-    def __init__(
-        self,
-        scene_sequence: RawSceneSequence,
-        query_points: QueryPointLookup,
-        query_flow_timestamps: list[Timestamp],
-    ):
-        assert isinstance(
-            scene_sequence, RawSceneSequence
-        ), f"scene_sequence must be a RawSceneSequence, got {type(scene_sequence)}"
-        assert isinstance(
-            query_points, QueryPointLookup
-        ), f"query_particles must be a dict, got {type(query_points)}"
-        assert isinstance(
-            query_flow_timestamps, list
-        ), f"query_timestamps must be a list, got {type(query_flow_timestamps)}"
-
-        self.scene_sequence = scene_sequence
-
-        ###################################################
-        # Sanity checks to ensure that the query is valid #
-        ###################################################
-
-        # Check that the query timestamps all have corresponding percepts
-        assert set(query_flow_timestamps).issubset(
-            set(self.scene_sequence.get_percept_timesteps())
-        ), f"Query timestamps {query_flow_timestamps} must be a subset of the scene sequence percepts {self.scene_sequence.get_percept_timesteps()}"
-
-        self.query_flow_timestamps = query_flow_timestamps
-        self.query_particles = query_points
-
-    def __len__(self) -> int:
-        return len(self.query_flow_timestamps)
-
-    def visualize(
-        self,
-        vis: O3DVisualizer,
-        percent_subsample: Union[None, float] = None,
-        verbose=False,
-    ) -> O3DVisualizer:
-        if percent_subsample is not None:
-            assert (
-                percent_subsample > 0 and percent_subsample <= 1
-            ), f"percent_subsample must be in (0, 1], got {percent_subsample}"
-            every_kth_particle = int(1 / percent_subsample)
-        else:
-            every_kth_particle = 1
-        # Visualize the query points ordered by particle ID
-        particle_ids = self.query_particles.particle_ids
-        world_particles = self.query_particles.valid_query_init_world_points()
-
-        kth_particle_ids = particle_ids[::every_kth_particle]
-        kth_world_particles = world_particles[::every_kth_particle]
-
-        assert len(kth_particle_ids) == len(
-            kth_world_particles
-        ), f"Expected kth_particle_ids and kth_world_particles to have the same length, got {len(kth_particle_ids)} and {len(kth_world_particles)} instead"
-
-        kth_particle_colors = [
-            _particle_id_to_color(particle_id) for particle_id in kth_particle_ids
-        ]
-        assert len(kth_particle_colors) == len(
-            kth_particle_ids
-        ), f"Expected kth_particle_colors and kth_particle_ids to have the same length, got {len(kth_particle_colors)} and {len(kth_particle_ids)} instead"
-
-        vis.add_spheres(kth_world_particles, 0.1, kth_particle_colors)
-        return vis
-
-
-class EstimatedPointFlow:
-    def __init__(
-        self,
-        num_entries: int,
-        trajectory_timestamps: Union[list[Timestamp], np.ndarray],
-    ):
-        self.num_entries = num_entries
-
-        if isinstance(trajectory_timestamps, list):
-            trajectory_timestamps = np.array(trajectory_timestamps)
-
-        assert (
-            trajectory_timestamps.ndim == 1
-        ), f"trajectory_timestamps must be a 1D array, got {trajectory_timestamps.ndim}"
-        self.trajectory_timestamps = trajectory_timestamps
-        self.trajectory_length = len(trajectory_timestamps)
-
-        self.world_points = np.zeros((num_entries, self.trajectory_length, 3), dtype=np.float32)
-
-        # By default, all trajectories are invalid
-        self.is_valid_flow = np.zeros((num_entries,), dtype=bool)
-
-    def valid_particle_ids(self) -> NDArray:
-        return np.arange(self.num_entries)[self.is_valid_flow]
-
-    def get_flow(
-        self, src_timestamp: Timestamp, target_timestamp: Timestamp
-    ) -> tuple[NDArray, NDArray]:
-        src_idx = np.where(self.trajectory_timestamps == src_timestamp)[0][0]
-        target_idx = np.where(self.trajectory_timestamps == target_timestamp)[0][0]
-        return (
-            self.world_points[:, target_idx] - self.world_points[:, src_idx],
-            self.is_valid_flow,
-        )
-
-    def __len__(self) -> int:
-        return self.is_valid_flow.sum()
-
-    def __setitem__(self, particle_id: ParticleID, points: NDArray):
-        self.world_points[particle_id] = points
-        self.is_valid_flow[particle_id] = True
-
-    def visualize(
-        self,
-        vis: O3DVisualizer,
-        percent_subsample: Union[None, float] = None,
-        verbose: bool = False,
-    ) -> O3DVisualizer:
-        if percent_subsample is not None:
-            assert (
-                percent_subsample > 0 and percent_subsample <= 1
-            ), f"percent_subsample must be in (0, 1], got {percent_subsample}"
-            every_kth_particle = int(1 / percent_subsample)
-        else:
-            every_kth_particle = 1
-
-        # Shape: points, 2, 3
-        world_points = self.world_points.copy()
-        world_points = world_points[self.is_valid_flow]
-        world_points = world_points[::every_kth_particle]
-
-        vis.add_trajectories(world_points)
-        return vis
-
-
-class GroundTruthPointFlow(EstimatedPointFlow):
-    def __init__(
-        self,
-        num_entries: int,
-        trajectory_timestamps: Union[list[Timestamp], np.ndarray],
-        query_timestamp: int,
-        class_name_map: Optional[dict[ParticleClassId, str]] = None,
-    ):
-        super().__init__(num_entries, trajectory_timestamps)
-        self.class_name_map = class_name_map
-        self.cls_ids = np.zeros((num_entries,), dtype=np.int64)
-        self.query_timestamp = query_timestamp
-        assert (
-            self.query_timestamp in self.trajectory_timestamps
-        ), f"query_timestamp {self.query_timestamp} must be in trajectory_timestamps {self.trajectory_timestamps}"
-
-    def _mask_entries(self, mask: np.ndarray):
-        assert mask.ndim == 1, f"mask must be a 1D array, got {mask.ndim}"
-
-        assert (
-            len(mask) == self.num_entries
-        ), f"mask must be the same length as the number of entries, got {len(mask)} and {self.num_entries} instead"
-
-        self.is_valid_flow[~mask] = False
-
-    def __setitem__(
-        self,
-        particle_id: ParticleID,
-        data_tuple: tuple[NDArray, ParticleClassId, NDArray],
-    ):
-        points, cls_ids, is_valids = data_tuple
-        self.world_points[particle_id] = points
-        self.cls_ids[particle_id] = cls_ids
-        self.is_valid_flow[particle_id] = is_valids
-
-    def pretty_name(self, class_id: ParticleClassId) -> str:
-        if self.class_name_map is None:
-            return str(class_id)
-
-        if class_id not in self.class_name_map:
-            return str(class_id)
-
-        return self.class_name_map[class_id]
diff --git a/bucketed_scene_flow_eval/eval/base_per_frame_sceneflow_eval.py b/bucketed_scene_flow_eval/eval/base_per_frame_sceneflow_eval.py
index c792850..f80a4f5 100644
--- a/bucketed_scene_flow_eval/eval/base_per_frame_sceneflow_eval.py
+++ b/bucketed_scene_flow_eval/eval/base_per_frame_sceneflow_eval.py
@@ -1,14 +1,17 @@
 import copy
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Set, Union
+from typing import Any, Iterable, Set, Union
 
 import numpy as np
 
 from bucketed_scene_flow_eval.datastructures import (
-    EstimatedPointFlow,
-    GroundTruthPointFlow,
-    Timestamp,
+    EgoLidarFlow,
+    PointCloud,
+    SemanticClassId,
+    SemanticClassIdArray,
+    TimeSyncedSceneFlowItem,
+    VectorArray,
 )
 from bucketed_scene_flow_eval.utils import save_json, save_pickle
 
@@ -40,18 +43,17 @@ def __eq__(self, __value: object) -> bool:
 class BaseEvalFrameResult:
     def __init__(
         self,
-        gt_world_points: np.ndarray,
-        gt_class_ids: np.ndarray,
-        gt_flow: np.ndarray,
-        pred_flow: np.ndarray,
-        class_id_to_name=lambda e: e,
+        gt_world_points: PointCloud,
+        gt_class_ids: SemanticClassIdArray,
+        gt_flow: VectorArray,
+        pred_flow: VectorArray,
+        class_id_to_name: dict[SemanticClassId, str],
         distance_thresholds: list[float] = [35, np.inf],
         max_speed_thresholds: list[tuple[float, float]] = [(0, np.inf)],
     ):
         self.distance_thresholds = distance_thresholds
         self.max_speed_thresholds = max_speed_thresholds
 
-        assert gt_world_points.ndim == 2, f"gt_world_points must be 3D, got {gt_world_points.ndim}"
         assert (
             gt_world_points.shape == gt_flow.shape
         ), f"gt_world_points and gt_flow must have the same shape, got {gt_world_points.shape} and {gt_flow.shape}"
@@ -60,8 +62,6 @@ def __init__(
             gt_flow.shape == pred_flow.shape
         ), f"gt_flow and pred_flow must have the same shape, got {gt_flow.shape} and {pred_flow.shape}"
 
-        gt_speeds = np.linalg.norm(gt_flow, axis=1)
-
         scaled_gt_flow, scaled_pred_flow = self._scale_flows(gt_flow, pred_flow)
 
         scaled_epe_errors = np.linalg.norm(scaled_gt_flow - scaled_pred_flow, axis=1)
@@ -70,14 +70,14 @@ def __init__(
             k: v
             for k, v in self.make_splits(
                 gt_world_points,
-                gt_speeds,
+                gt_flow,
                 gt_class_ids,
                 scaled_epe_errors,
                 class_id_to_name,
             )
         }
 
-    def _get_gt_classes(self, gt_class_ids: np.ndarray) -> Set[int]:
+    def _get_gt_classes(self, gt_class_ids: SemanticClassIdArray) -> SemanticClassIdArray:
         return np.unique(gt_class_ids)
 
     def _get_distance_thresholds(self) -> list[float]:
@@ -92,8 +92,14 @@ def _scale_flows(
         return gt_flow, pred_flow
 
     def make_splits(
-        self, gt_world_points, gt_speeds, gt_class_ids, epe_errors, class_id_to_name
-    ) -> list[tuple[BaseSplitKey, BaseSplitValue]]:
+        self,
+        gt_world_points: PointCloud,
+        gt_flow: VectorArray,
+        gt_class_ids: SemanticClassIdArray,
+        epe_errors: np.ndarray,
+        class_id_to_name: dict[SemanticClassId, str],
+    ) -> Iterable[tuple[BaseSplitKey, BaseSplitValue]]:
+        gt_speeds = np.linalg.norm(gt_flow, axis=1)
         unique_gt_classes = self._get_gt_classes(gt_class_ids)
         distance_thresholds = self._get_distance_thresholds()
         speed_threshold_tuples = self._get_max_speed_thresholds()
@@ -111,7 +117,7 @@ def make_splits(
                     continue
                 for distance_threshold in distance_thresholds:
                     within_distance_mask = (
-                        np.linalg.norm(gt_world_points[:, :2], ord=np.inf, axis=1)
+                        np.linalg.norm(gt_world_points.points[:, :2], ord=np.inf, axis=1)
                         < distance_threshold
                     )
 
@@ -123,20 +129,25 @@ def make_splits(
 
                     avg_epe = np.sum(epe_errors[match_mask]) / count
                     split_avg_speed = np.mean(gt_speeds[match_mask])
-                    class_name = class_id_to_name(class_id)
+                    class_name = class_id_to_name[class_id]
                     yield BaseSplitKey(
                         class_name, distance_threshold, speed_threshold_tuple
                     ), BaseSplitValue(avg_epe, count, split_avg_speed)
 
 
 class PerFrameSceneFlowEvaluator(Evaluator):
-    def __init__(self, output_path: Path = Path("/tmp/frame_results")):
+    def __init__(
+        self,
+        class_id_to_name: dict[SemanticClassId, str],
+        output_path: Path = Path("/tmp/frame_results"),
+    ):
         output_path = Path(output_path)
         self.eval_frame_results: list[BaseEvalFrameResult] = []
         self.output_path = output_path
         # print(f"Saving results to {self.output_path}")
         # make the directory if it doesn't exist
         self.output_path.mkdir(parents=True, exist_ok=True)
+        self.class_id_to_name = class_id_to_name
 
     @staticmethod
     def from_evaluator_list(evaluator_list: list["PerFrameSceneFlowEvaluator"]):
@@ -163,157 +174,54 @@ def __add__(self, other: "PerFrameSceneFlowEvaluator"):
     def __len__(self):
         return len(self.eval_frame_results)
 
-    def _validate_inputs(
+    def _sanitize_and_validate_inputs(
         self,
-        predictions: EstimatedPointFlow,
-        ground_truth: GroundTruthPointFlow,
+        predictions: EgoLidarFlow,
+        ground_truth: TimeSyncedSceneFlowItem,
     ):
         assert isinstance(
-            predictions, EstimatedPointFlow
-        ), f"predictions must be a EstimatedParticleTrajectories, got {type(predictions)}"
+            predictions, EgoLidarFlow
+        ), f"predictions must be a EstimatedFlows, got {type(predictions)}"
 
         assert isinstance(
-            ground_truth, GroundTruthPointFlow
-        ), f"ground_truth must be a GroundTruthParticleTrajectories, got {type(ground_truth)}"
-
-        # Validate that the predictions and ground truth have the same underlying size.
-        assert (
-            predictions.num_entries == ground_truth.num_entries
-        ), f"predictions and ground_truth must have the same number of predictions, got {predictions.num_entries} and {ground_truth.num_entries}"
-
-        # Validate that the valid ground truths are the same as the valid predictions (it's OK to have more valid predictions than ground truths).
-        assert (
-            predictions.is_valid_flow.shape == ground_truth.is_valid_flow.shape
-        ), f"predictions and ground_truth must have the same shape, got {predictions.is_valid_flow.shape} and {ground_truth.is_valid_flow.shape}"
-        assert (
-            (predictions.is_valid_flow & ground_truth.is_valid_flow) == ground_truth.is_valid_flow
-        ).all(), f"predictions and ground_truth must have the same valid entries, however some were missing."
+            ground_truth, TimeSyncedSceneFlowItem
+        ), f"ground_truth must be a GroundTruthFlows, got {type(ground_truth)}"
 
-        predictions.is_valid_flow = ground_truth.is_valid_flow
-
-        assert (
-            len(predictions) > 0
-        ), f"predictions must have at least one prediction, got {len(predictions)}"
-
-        # All Ground Truth Particle Trajectories must be in the set of Estimation Particle Trajectories.
-        # It's acceptable for the Estimation Particle Trajectories to have more trajectories than
-        # the Ground Truth Particle Trajectories.
-
-        predictions_intersection_ground_truth = (
-            predictions.is_valid_flow & ground_truth.is_valid_flow
-        )
-        predictions_match_ground_truth = (
-            predictions_intersection_ground_truth == ground_truth.is_valid_flow
-        )
-        vectors = ground_truth.world_points[~predictions_match_ground_truth]
-        assert (
-            predictions_match_ground_truth
-        ).all(), f"all ground truth particle trajectories must be in the estimation particle trajectories. Nonmatching points: {(~predictions_match_ground_truth).sum()}. Violating vectors: {vectors}"
-
-        # All timestamps for the Ground Truth Particle Trajectories must be in the set of Estimation Particle Trajectories.
-        # It's acceptable for the Estimation Particle Trajectories to have more timestamps than
-        # the Ground Truth Particle Trajectories.
-        assert set(ground_truth.trajectory_timestamps).issubset(
-            set(predictions.trajectory_timestamps)
-        ), f"all timestamps for the ground truth particle trajectories must be in the estimation particle trajectories. Nonmatching timestamps: {set(ground_truth.trajectory_timestamps) - set(predictions.trajectory_timestamps)}"
-
-    def _get_indices_of_timestamps(
-        self,
-        predictions: EstimatedPointFlow,
-        ground_truth: GroundTruthPointFlow,
-        query_timestamp: Timestamp,
-    ):
-        # create an numpy array
-        pred_timestamps = predictions.trajectory_timestamps
+        # Ensure that the predictions' underlying flow array has the same length as the gt's
+        assert len(predictions.full_flow) == len(
+            ground_truth.flow.full_flow
+        ), f"predictions and ground_truth must have the same length, got {len(predictions.full_flow)} and {len(ground_truth.flow.full_flow)}"
 
-        traj_timestamps = ground_truth.trajectory_timestamps
+        # Validate that all valid gt flow vectors are considered valid in the predictions.
+        assert np.all(
+            (predictions.mask & ground_truth.flow.mask) == ground_truth.flow.mask
+        ), "All valid gt flow vectors must be considered valid in the predictions"
 
-        # index of first occurrence of each value
-        sorter = np.argsort(pred_timestamps)
+        # Set the prediction's valid flow mask to the gt mask so everything lines up
+        predictions.mask = ground_truth.flow.mask
 
-        matched_idxes = sorter[np.searchsorted(pred_timestamps, traj_timestamps, sorter=sorter)]
+    def eval(self, predicted_flow: EgoLidarFlow, gt_frame: TimeSyncedSceneFlowItem):
+        self._sanitize_and_validate_inputs(predicted_flow, gt_frame)
 
-        # find the index of the query timestamp in traj_timestamps
-        query_idx = np.where(traj_timestamps == query_timestamp)[0][0]
+        is_valid_flow_mask = gt_frame.flow.mask
 
-        return matched_idxes, query_idx
+        global_pc = gt_frame.pc.full_global_pc.mask_points(is_valid_flow_mask)
+        class_ids = gt_frame.pc.full_pc_classes[is_valid_flow_mask]
+        gt_flow = gt_frame.flow.valid_flow
+        pred_flow = predicted_flow.valid_flow
 
-    def eval(
-        self,
-        predictions: EstimatedPointFlow,
-        ground_truth: GroundTruthPointFlow,
-        query_timestamp: Timestamp,
-    ):
-        self._validate_inputs(predictions, ground_truth)
-
-        # Extract the ground truth entires for the timestamps that are in both the predictions and ground truth.
-        # It could be that the predictions have more timestamps than the ground truth.
-
-        matched_time_axis_indices, query_idx = self._get_indices_of_timestamps(
-            predictions, ground_truth, query_timestamp
-        )
-
-        # We only support Scene Flow
-        if query_idx != 0:
-            raise NotImplementedError("TODO: Handle query_idx != 0 when computing speed bucketing.")
-
-        eval_particle_ids = ground_truth.valid_particle_ids()
-
-        gt_is_valids = ground_truth.is_valid_flow[eval_particle_ids]
-
-        pred_is_valids = predictions.is_valid_flow[eval_particle_ids]
-
-        # Make sure that all the pred_is_valids are true if gt_is_valids is true.
-        assert (
-            (gt_is_valids & pred_is_valids) == gt_is_valids
-        ).all(), f"all gt_is_valids must be true if pred_is_valids is true."
-
-        gt_world_points = ground_truth.world_points[eval_particle_ids][:, matched_time_axis_indices]
-        pred_world_points = predictions.world_points[eval_particle_ids][
-            :, matched_time_axis_indices
-        ]
-
-        gt_class_ids = ground_truth.cls_ids[eval_particle_ids]
-
-        assert (
-            gt_world_points.shape[1] == 2
-        ), f"gt_world_points must have 2 timestamps; we only support Scene Flow. Instead we got {gt_world_points.shape[1]} dimensions."
-        assert (
-            pred_world_points.shape[1] == 2
-        ), f"pred_world_points must have 2 timestamps; we only support Scene Flow. Instead we got {pred_world_points.shape[1]} dimensions."
-
-        # Query index should have roughly the same values.
-        assert np.isclose(
-            gt_world_points[:, query_idx], pred_world_points[:, query_idx]
-        ).all(), f"gt_world_points and pred_world_points should have the same values for the query index, got {gt_world_points[:, query_idx]} and {pred_world_points[:, query_idx]}"
-
-        pc1 = gt_world_points[:, 0]
-        gt_pc2 = gt_world_points[:, 1]
-        pred_pc2 = pred_world_points[:, 1]
-
-        gt_flow = gt_pc2 - pc1
-        pred_flow = pred_pc2 - pc1
-
-        eval_frame_result = self._build_eval_frame_results(
-            pc1, gt_class_ids, gt_flow, pred_flow, ground_truth
+        self.eval_frame_results.append(
+            self._build_eval_frame_results(global_pc, class_ids, gt_flow, pred_flow)
         )
 
-        self.eval_frame_results.append(eval_frame_result)
-
     def _build_eval_frame_results(
-        self,
-        pc1: np.ndarray,
-        gt_class_ids: np.ndarray,
-        gt_flow: np.ndarray,
-        pred_flow: np.ndarray,
-        ground_truth: GroundTruthPointFlow,
+        self, pc1: np.ndarray, gt_class_ids: np.ndarray, gt_flow: np.ndarray, pred_flow: np.ndarray
     ) -> BaseEvalFrameResult:
+        """
+        Override this method to build a custom EvalFrameResult child construction
+        """
         return BaseEvalFrameResult(
-            pc1,
-            gt_class_ids,
-            gt_flow,
-            pred_flow,
-            class_id_to_name=ground_truth.pretty_name,
+            pc1, gt_class_ids, gt_flow, pred_flow, class_id_to_name=self.class_id_to_name
         )
 
     def _save_intermediary_results(self):
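
The sanitize-then-mask contract above reduces to a small invariant. A minimal
numpy sketch, with hypothetical toy arrays standing in for EgoLidarFlow.full_flow
and EgoLidarFlow.mask; only the invariant check and the mask overwrite mirror
the code in _sanitize_and_validate_inputs:

    import numpy as np

    # Hypothetical toy data standing in for EgoLidarFlow.full_flow / .mask.
    gt_full_flow = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.5, 0.5, 0.0]])
    gt_mask = np.array([True, False, True])
    pred_full_flow = np.zeros_like(gt_full_flow)
    pred_mask = np.array([True, True, True])  # predictions may be valid in more places

    # Invariant: every gt-valid flow vector must also be pred-valid.
    assert np.all((pred_mask & gt_mask) == gt_mask)

    # The prediction mask is then overwritten with the gt mask, so the
    # valid slices of both containers line up one-to-one.
    pred_mask = gt_mask
    epe = np.linalg.norm(gt_full_flow[gt_mask] - pred_full_flow[pred_mask], axis=1)
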
diff --git a/bucketed_scene_flow_eval/eval/bucketed_epe.py b/bucketed_scene_flow_eval/eval/bucketed_epe.py
index 8525ce5..716f912 100644
--- a/bucketed_scene_flow_eval/eval/bucketed_epe.py
+++ b/bucketed_scene_flow_eval/eval/bucketed_epe.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 
+from bucketed_scene_flow_eval.datastructures import SemanticClassId
 from bucketed_scene_flow_eval.utils import save_json, save_txt
 
 from .base_per_frame_sceneflow_eval import (
@@ -220,6 +221,7 @@ def to_full_latex(self, normalized: bool = True) -> str:
 class BucketedEPEEvaluator(PerFrameSceneFlowEvaluator):
     def __init__(
         self,
+        class_id_to_name: dict[SemanticClassId, str],
         bucket_max_speed: float = 20.0 / 10.0,
         num_buckets: int = 51,
         output_path: Path = Path("/tmp/frame_results/bucketed_epe"),
@@ -230,22 +232,17 @@ def __init__(
         bucket_edges = np.concatenate([np.linspace(0, bucket_max_speed, num_buckets), [np.inf]])
         self.speed_thresholds = list(zip(bucket_edges, bucket_edges[1:]))
         self.meta_class_lookup = meta_class_lookup
-        super().__init__(output_path=output_path)
+        super().__init__(output_path=output_path, class_id_to_name=class_id_to_name)
 
     def _build_eval_frame_results(
-        self,
-        pc1: np.ndarray,
-        gt_class_ids: np.ndarray,
-        gt_flow: np.ndarray,
-        pred_flow: np.ndarray,
-        ground_truth,
+        self, pc1: np.ndarray, gt_class_ids: np.ndarray, gt_flow: np.ndarray, pred_flow: np.ndarray
     ) -> BaseEvalFrameResult:
         return BucketedEvalFrameResult(
             pc1,
             gt_class_ids,
             gt_flow,
             pred_flow,
-            class_id_to_name=ground_truth.pretty_name,
+            class_id_to_name=self.class_id_to_name,
             max_speed_thresholds=self.speed_thresholds,
         )
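
The speed bucketing is driven entirely by the edges built in the constructor
above. A standalone sketch using the defaults shown in this diff (the
class_id_to_name comment is a hypothetical example mapping):

    import numpy as np

    # Defaults from BucketedEPEEvaluator: 2.0 m/frame cap, 51 edge samples.
    bucket_max_speed = 20.0 / 10.0
    num_buckets = 51
    bucket_edges = np.concatenate([np.linspace(0, bucket_max_speed, num_buckets), [np.inf]])
    # Adjacent edges form (min, max) speed ranges; the final (2.0, inf)
    # bucket catches everything faster than the cap.
    speed_thresholds = list(zip(bucket_edges, bucket_edges[1:]))
    assert len(speed_thresholds) == num_buckets  # 50 finite buckets + 1 catch-all
    # The evaluator itself now also requires an explicit class_id_to_name
    # mapping, e.g. {0: "BACKGROUND", 1: "CAR"} (hypothetical ids).
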
 
diff --git a/bucketed_scene_flow_eval/eval/eval.py b/bucketed_scene_flow_eval/eval/eval.py
index b26229a..cf976c4 100644
--- a/bucketed_scene_flow_eval/eval/eval.py
+++ b/bucketed_scene_flow_eval/eval/eval.py
@@ -3,20 +3,14 @@
 from typing import Any
 
 from bucketed_scene_flow_eval.datastructures import (
-    EstimatedPointFlow,
-    GroundTruthPointFlow,
-    Timestamp,
+    EgoLidarFlow,
+    TimeSyncedSceneFlowItem,
 )
 
 
 class Evaluator(ABC):
     @abstractmethod
-    def eval(
-        self,
-        predictions: EstimatedPointFlow,
-        ground_truth: GroundTruthPointFlow,
-        query_timestamp: Timestamp,
-    ):
+    def eval(self, predictions: EgoLidarFlow, gt: TimeSyncedSceneFlowItem):
         pass
 
     @abstractmethod
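
Call sites shrink to two arguments; a hedged driver sketch (only the eval
signature and the imports are taken from this diff, the helper name is made up):

    from bucketed_scene_flow_eval.datastructures import (
        EgoLidarFlow,
        TimeSyncedSceneFlowItem,
    )
    from bucketed_scene_flow_eval.eval.eval import Evaluator

    def run_frame_eval(
        evaluator: Evaluator, pred: EgoLidarFlow, gt: TimeSyncedSceneFlowItem
    ) -> None:
        # The query timestamp argument is gone: the time-synced gt item now
        # carries the point cloud, class ids, and gt flow itself.
        evaluator.eval(pred, gt)
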
diff --git a/bucketed_scene_flow_eval/datasets/shared_datastructures/__init__.py b/bucketed_scene_flow_eval/interfaces/__init__.py
similarity index 65%
rename from bucketed_scene_flow_eval/datasets/shared_datastructures/__init__.py
rename to bucketed_scene_flow_eval/interfaces/__init__.py
index 32c1139..9ee131e 100644
--- a/bucketed_scene_flow_eval/datasets/shared_datastructures/__init__.py
+++ b/bucketed_scene_flow_eval/interfaces/__init__.py
@@ -1,13 +1,13 @@
+from .abstract_dataset import AbstractDataset
 from .abstract_sequence_loader import (
     AbstractSequence,
     AbstractSequenceLoader,
     CachedSequenceLoader,
 )
-from .scene_representations import RawItem, SceneFlowItem
 
 __all__ = [
+    "AbstractDataset",
     "AbstractSequence",
     "AbstractSequenceLoader",
-    "RawItem",
-    "SceneFlowItem",
+    "CachedSequenceLoader",
 ]
diff --git a/bucketed_scene_flow_eval/interfaces/abstract_dataset.py b/bucketed_scene_flow_eval/interfaces/abstract_dataset.py
new file mode 100644
index 0000000..6dfd34e
--- /dev/null
+++ b/bucketed_scene_flow_eval/interfaces/abstract_dataset.py
@@ -0,0 +1,14 @@
+# import abstract base class
+from abc import ABC, abstractmethod
+
+from bucketed_scene_flow_eval.datastructures import TimeSyncedSceneFlowItem
+
+
+class AbstractDataset(ABC):
+    @abstractmethod
+    def __getitem__(self, idx: int) -> list[TimeSyncedSceneFlowItem]:
+        pass
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
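
A minimal concrete dataset against this interface; a toy in-memory sketch
(construction of the TimeSyncedSceneFlowItem instances is left to the caller,
since their fields are defined elsewhere in this patch):

    from bucketed_scene_flow_eval.datastructures import TimeSyncedSceneFlowItem
    from bucketed_scene_flow_eval.interfaces import AbstractDataset

    class InMemoryDataset(AbstractDataset):
        """Toy dataset: wraps a prebuilt list of frame lists."""

        def __init__(self, samples: list[list[TimeSyncedSceneFlowItem]]):
            self.samples = samples

        def __getitem__(self, idx: int) -> list[TimeSyncedSceneFlowItem]:
            return self.samples[idx]

        def __len__(self) -> int:
            return len(self.samples)
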
diff --git a/bucketed_scene_flow_eval/datasets/shared_datastructures/abstract_sequence_loader.py b/bucketed_scene_flow_eval/interfaces/abstract_sequence_loader.py
similarity index 83%
rename from bucketed_scene_flow_eval/datasets/shared_datastructures/abstract_sequence_loader.py
rename to bucketed_scene_flow_eval/interfaces/abstract_sequence_loader.py
index 7ca70f8..e681773 100644
--- a/bucketed_scene_flow_eval/datasets/shared_datastructures/abstract_sequence_loader.py
+++ b/bucketed_scene_flow_eval/interfaces/abstract_sequence_loader.py
@@ -1,7 +1,10 @@
 # import abstract base class
 from abc import ABC, abstractmethod
 
-from .scene_representations import RawItem
+from bucketed_scene_flow_eval.datastructures import (
+    TimeSyncedBaseAuxilaryData,
+    TimeSyncedRawItem,
+)
 
 
 class AbstractSequence(ABC):
@@ -9,7 +12,9 @@ def __init__(self):
         pass
 
     @abstractmethod
-    def load(self, idx: int, relative_to_idx: int) -> RawItem:
+    def load(
+        self, idx: int, relative_to_idx: int
+    ) -> tuple[TimeSyncedRawItem, TimeSyncedBaseAuxilaryData]:
         pass
 
     @abstractmethod
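
Every sequence implementation must now return the auxiliary data alongside the
raw item; a hedged skeleton (only the signature comes from this diff, the body
and class name are hypothetical):

    from bucketed_scene_flow_eval.datastructures import (
        TimeSyncedBaseAuxilaryData,
        TimeSyncedRawItem,
    )
    from bucketed_scene_flow_eval.interfaces import AbstractSequence

    class MySequence(AbstractSequence):
        def load(
            self, idx: int, relative_to_idx: int
        ) -> tuple[TimeSyncedRawItem, TimeSyncedBaseAuxilaryData]:
            # Hypothetical body: build the raw item and its auxiliary data
            # (e.g. ground/in-range masks) and return them as a pair.
            raise NotImplementedError
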
diff --git a/scripts/demo.py b/scripts/demo.py
deleted file mode 100644
index ddf847e..0000000
--- a/scripts/demo.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import argparse
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-from matplotlib import pyplot as plt
-
-from bucketed_scene_flow_eval.datasets import construct_dataset
-from bucketed_scene_flow_eval.datastructures import O3DVisualizer, QuerySceneSequence
-
-
-def visualize_lidar_3d(query: QuerySceneSequence):
-    scene_timestamp = query.query_particles.query_init_timestamp
-
-    rgb_frames = query.scene_sequence[scene_timestamp].rgb_frames
-    pc_frame = query.scene_sequence[scene_timestamp].pc_frame
-
-    o3d_vis = O3DVisualizer()
-    o3d_vis.add_pointcloud(pc_frame.global_pc)
-    for rgb_frame in rgb_frames.values():
-        image_plane_pc, colors = rgb_frame.camera_projection.image_to_image_plane_pc(
-            rgb_frame.rgb, depth=20
-        )
-        image_plane_pc = image_plane_pc.transform(rgb_frame.pose.sensor_to_ego.inverse())
-        o3d_vis.add_pointcloud(image_plane_pc, color=colors)
-    o3d_vis.run()
-    del o3d_vis
-
-
-if __name__ == "__main__":
-    # Take arguments to specify dataset and root directory
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", type=str, default="Argoverse2SceneFlow")
-    parser.add_argument("--root_dir", type=str, default="/efs/argoverse2/val")
-    parser.add_argument("--skip_rgb", action="store_true")
-    args = parser.parse_args()
-
-    dataset = construct_dataset(
-        args.dataset, dict(root_dir=args.root_dir, with_rgb=not args.skip_rgb)
-    )
-
-    print("Dataset contains", len(dataset), "samples")
-
-    for idx, (query, gt) in enumerate(dataset):
-        visualize_lidar_3d(query)
diff --git a/scripts/demo_3d.py b/scripts/demo_3d.py
new file mode 100644
index 0000000..104863b
--- /dev/null
+++ b/scripts/demo_3d.py
@@ -0,0 +1,54 @@
+import argparse
+from pathlib import Path
+
+from bucketed_scene_flow_eval.datasets import construct_dataset
+from bucketed_scene_flow_eval.datastructures import (
+    O3DVisualizer,
+    TimeSyncedSceneFlowItem,
+)
+
+
+def visualize_lidar_3d(frame_list: list[TimeSyncedSceneFlowItem], downscale_rgb_factor: int):
+    o3d_vis = O3DVisualizer()
+
+    print("Visualizing", len(frame_list), "frames")
+
+    for frame_idx, frame in enumerate(frame_list):
+        rgb_frames = frame.rgbs
+        pc_frame = frame.pc
+
+        o3d_vis.add_global_pc_frame(pc_frame)
+        for name, rgb_frame in rgb_frames.items():
+            print(f"Adding RGB frame {frame_idx} {name}")
+            rgb_frame = rgb_frame.rescale(downscale_rgb_factor)
+            # print("RGB Frame ego pose:", rgb_frame.pose.ego_to_global.translation)
+            o3d_vis.add_pose(rgb_frame.pose.ego_to_global)
+            o3d_vis.add_global_rgb_frame(rgb_frame)
+    o3d_vis.run()
+    del o3d_vis
+
+
+if __name__ == "__main__":
+    # Take arguments to specify dataset and root directory
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, default="Argoverse2SceneFlow")
+    parser.add_argument("--root_dir", type=Path, default="/efs/argoverse2/val")
+    parser.add_argument("--with_rgb", action="store_true")
+    parser.add_argument("--sequence_length", type=int, default=2)
+    parser.add_argument("--downscale_rgb_factor", type=int, default=8)
+    args = parser.parse_args()
+
+    dataset = construct_dataset(
+        args.dataset,
+        dict(
+            root_dir=args.root_dir, with_rgb=args.with_rgb, subsequence_length=args.sequence_length
+        ),
+    )
+    assert len(dataset) > 0, "Dataset is empty"
+    print("Dataset contains", len(dataset), "samples")
+
+    vis_index = args.sequence_length
+
+    print("Loading sequence idx", vis_index)
+    frame_list = dataset[vis_index]
+    visualize_lidar_3d(frame_list, args.downscale_rgb_factor)
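
To sweep every sample rather than the single index above, the tail of the
__main__ block could be replaced with a loop (hypothetical variant, reusing the
script's own names):

    # Drop-in for the last three lines of the __main__ block: visualize
    # every sample instead of only index args.sequence_length.
    for vis_index in range(len(dataset)):
        print("Loading sequence idx", vis_index)
        visualize_lidar_3d(dataset[vis_index], args.downscale_rgb_factor)
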
diff --git a/tests/argoverse2/av2_tests.py b/tests/argoverse2/av2_tests.py
index 84c1310..15cbb9a 100644
--- a/tests/argoverse2/av2_tests.py
+++ b/tests/argoverse2/av2_tests.py
@@ -6,8 +6,12 @@
 from bucketed_scene_flow_eval.datasets.argoverse2 import (
     ArgoverseSceneFlowSequenceLoader,
 )
-from bucketed_scene_flow_eval.datasets.shared_datastructures import RawItem
-from bucketed_scene_flow_eval.datastructures import SE3, PoseInfo
+from bucketed_scene_flow_eval.datastructures import (
+    SE3,
+    PoseInfo,
+    TimeSyncedAVLidarData,
+    TimeSyncedSceneFlowItem,
+)
 
 
 @pytest.fixture
@@ -32,7 +36,9 @@ def _are_poseinfos_close(pose1: PoseInfo, pose2: PoseInfo, tol: float = 1e-6) ->
     )
 
 
-def _load_reference_sequence(av2_loader: ArgoverseSceneFlowSequenceLoader) -> RawItem:
+def _load_reference_sequence(
+    av2_loader: ArgoverseSceneFlowSequenceLoader,
+) -> tuple[TimeSyncedSceneFlowItem, TimeSyncedAVLidarData]:
     sequence_id = "02678d04-cc9f-3148-9f95-1ba66347dff9"
     assert sequence_id in av2_loader.get_sequence_ids(), f"sequence_id {sequence_id} not found"
     sequence = av2_loader.load_sequence(sequence_id)
@@ -41,7 +47,7 @@ def _load_reference_sequence(av2_loader: ArgoverseSceneFlowSequenceLoader) -> Ra
 
 
 def test_rgb_sizes(av2_loader: ArgoverseSceneFlowSequenceLoader):
-    first_frame = _load_reference_sequence(av2_loader)
+    first_frame, _ = _load_reference_sequence(av2_loader)
     assert len(first_frame.rgbs) == 5, f"expected 5 cameras, got {len(first_frame.rgbs)}"
 
     # Expect the shapes
@@ -65,7 +71,7 @@ def test_rgb_sizes(av2_loader: ArgoverseSceneFlowSequenceLoader):
 
 
 def test_rgb_poses(av2_loader: ArgoverseSceneFlowSequenceLoader):
-    first_frame = _load_reference_sequence(av2_loader)
+    first_frame, _ = _load_reference_sequence(av2_loader)
     assert len(first_frame.rgbs) == 5, f"expected 5 cameras, got {len(first_frame.rgbs)}"
     # fmt: off
     expected_poses = {
diff --git a/tests/eval/bucketed_epe.py b/tests/eval/bucketed_epe.py
index 6e15754..b49b65d 100644
--- a/tests/eval/bucketed_epe.py
+++ b/tests/eval/bucketed_epe.py
@@ -2,7 +2,10 @@
 import pytest
 
 from bucketed_scene_flow_eval.datasets import Argoverse2SceneFlow, construct_dataset
-from bucketed_scene_flow_eval.datastructures import GroundTruthPointFlow
+from bucketed_scene_flow_eval.datastructures import (
+    EgoLidarFlow,
+    TimeSyncedSceneFlowItem,
+)
 
 
 @pytest.fixture
@@ -72,18 +75,24 @@ def _run_eval_on_target_and_gt_datasets(
 
     # Iterate over both datasets, treating the pseudo dataset as the "prediction"
     # and the ground truth dataset as the "target"
-    for (_, est_gt), (_, target_gt) in zip(target_dataset, gt_dataset):
-        est_gt: GroundTruthPointFlow
-        target_gt: GroundTruthPointFlow
-        assert all(est_gt.trajectory_timestamps == target_gt.trajectory_timestamps), (
-            f"Timestamps must match between the ground truth and pseudo datasets. "
-            f"Found {est_gt.trajectory_timestamps} and {target_gt.trajectory_timestamps}."
+    iterations = 0
+    for target_lst, gt_lst in zip(target_dataset, gt_dataset):
+        assert len(target_lst) == len(gt_lst) == 2, (
+            f"Each sample must be a tuple of length 2. "
+            f"Found {len(target_lst)} and {len(gt_lst)}."
         )
-        assert (
-            len(target_gt.trajectory_timestamps) == 2
-        ), f"Timestamps must be a pair of timestamps. Found {target_gt.trajectory_timestamps}."
+        target_item1: TimeSyncedSceneFlowItem = target_lst[0]
+        gt_item1: TimeSyncedSceneFlowItem = gt_lst[0]
 
-        evaluator.eval(est_gt, target_gt, target_gt.trajectory_timestamps[0])
+        evaluator.eval(
+            target_item1.flow,
+            gt_item1,
+        )
+        iterations += 1
+
+    assert iterations == len(
+        gt_dataset
+    ), f"Expected to iterate over {len(gt_dataset)} samples, but only iterated over {iterations}."
 
     out_results_dict: dict[str, tuple[float, float]] = evaluator.compute_results()
 
@@ -98,6 +107,8 @@ def _run_eval_on_target_and_gt_datasets(
         f"Found {out_results_dict.keys()} and {EXPECTED_RESULTS_DICT.keys()}."
     )
 
+    print(out_results_dict)
+
     for key in EXPECTED_RESULTS_DICT:
         out_static_epe, out_dynamic_epe = out_results_dict[key]
         exp_static_epe, exp_dynamic_epe = EXPECTED_RESULTS_DICT[key]
@@ -105,11 +116,11 @@ def _run_eval_on_target_and_gt_datasets(
         # Check that floats are equal, but be aware of NaNs (which are not equal to anything)
         assert np.isnan(out_static_epe) == np.isnan(
             exp_static_epe
-        ), f"Static EPEs must both be NaN or not NaN. Found {out_static_epe} and {exp_static_epe}."
+        ), f"Static EPEs must both be NaN or not NaN. Found output is {out_static_epe} but expected {exp_static_epe}."
 
         assert np.isnan(out_dynamic_epe) == np.isnan(
             exp_dynamic_epe
-        ), f"Dynamic EPEs must both be NaN or not NaN. Found {out_dynamic_epe} and {exp_dynamic_epe}."
+        ), f"Dynamic EPEs must both be NaN or not NaN. Found output is {out_dynamic_epe} but expected {exp_dynamic_epe}."
 
         if not np.isnan(exp_static_epe):
             assert out_static_epe == pytest.approx(
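
The NaN handling in these assertions distills into a small helper (hypothetical;
not part of this patch):

    import numpy as np
    import pytest

    def assert_epe_matches(out_epe: float, exp_epe: float) -> None:
        # NaN compares unequal to everything, including itself, so check
        # NaN-ness first and only then compare values approximately.
        assert np.isnan(out_epe) == np.isnan(exp_epe)
        if not np.isnan(exp_epe):
            assert out_epe == pytest.approx(exp_epe)
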
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 7fddafb..99df037 100644
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -6,6 +6,7 @@
 
 from bucketed_scene_flow_eval.datasets import construct_dataset
 from bucketed_scene_flow_eval.datastructures import *
+from bucketed_scene_flow_eval.interfaces import AbstractDataset
 
 
 @pytest.fixture
@@ -94,108 +95,10 @@ def argo_dataset_test_no_flow_with_ground():
     )
 
 
-def _process_query(
-    query: QuerySceneSequence,
-) -> tuple[tuple[np.ndarray, np.ndarray], tuple[SE3, SE3], list[np.ndarray], list[SE3]]:
-    assert (
-        len(query.query_flow_timestamps) == 2
-    ), f"Query {query} has more than two timestamps. Only Scene Flow problems are supported."
-    scene = query.scene_sequence
-
-    # These contain the all problem percepts, not just the ones in the query.
-    all_percept_pc_arrays: list[np.ndarray] = []
-    all_percept_poses: list[SE3] = []
-    # These contain only the percepts in the query.
-    query_pc_arrays: list[np.ndarray] = []
-    query_poses: list[SE3] = []
-
-    for timestamp in scene.get_percept_timesteps():
-        pc_frame = scene[timestamp].pc_frame
-        pc_array = pc_frame.full_global_pc.points.astype(np.float32)
-        pose = pc_frame.global_pose
-
-        all_percept_pc_arrays.append(pc_array)
-        all_percept_poses.append(pose)
-
-        if timestamp in query.query_flow_timestamps:
-            query_pc_arrays.append(pc_array)
-            query_poses.append(pose)
-
-    assert len(all_percept_pc_arrays) == len(
-        all_percept_poses
-    ), f"Percept arrays and poses have different lengths."
-    assert len(query_pc_arrays) == len(
-        query_poses
-    ), f"Percept arrays and poses have different lengths."
-    assert len(query_pc_arrays) == len(
-        query.query_flow_timestamps
-    ), f"Percept arrays and poses have different lengths."
-
-    return (
-        query_pc_arrays,
-        query_poses,
-        all_percept_pc_arrays,
-        all_percept_poses,
-    )
-
-
-def _process_gt(result: GroundTruthPointFlow):
-    flowed_source_pc = result.world_points[:, 1].astype(np.float32)
-    is_valid_mask = result.is_valid_flow
-    point_cls_array = result.cls_ids
-    return flowed_source_pc, is_valid_mask, point_cls_array
-
-
-def _validate_dataloader_elements(
-    query: QuerySceneSequence,
-    gt: GroundTruthPointFlow,
-    expected_pc_size: int,
-    expected_is_valid_entries: int,
-):
-    assert isinstance(query, QuerySceneSequence), f"Expected QuerySceneSequence, got {type(query)}"
-    assert isinstance(
-        gt, GroundTruthPointFlow
-    ), f"Expected GroundTruthParticleTrajectories, got {type(gt)}"
-
-    t1, t2 = query.scene_sequence.get_percept_timesteps()
-    pc_frame = query.scene_sequence[t1].pc_frame
-
-    assert (
-        len(pc_frame.global_pc) == expected_pc_size
-    ), f"Expected {expected_pc_size} points, got {len(pc_frame.global_pc)}"
-
-    assert (
-        gt.is_valid_flow.sum() == expected_is_valid_entries
-    ), f"Expected {expected_is_valid_entries} valid entries, got {gt.is_valid_flow.sum()}"
-
-    (
-        (source_pc, target_pc),
-        (source_pose, target_pose),
-        full_pc_points_list,
-        full_pc_poses_list,
-    ) = _process_query(query)
-
-    gt_flowed_source_pc, is_valid_flow_mask, gt_point_classes = _process_gt(gt)
-
-    assert (
-        source_pc.shape == gt_flowed_source_pc.shape
-    ), f"Source PC shape mismatch: {source_pc.shape} vs {gt_flowed_source_pc.shape}"
-
-    assert source_pc.shape[0] == is_valid_flow_mask.shape[0], (
-        f"Source PC and is_valid_flow_mask shape mismatch: "
-        f"{source_pc.shape[0]} vs {is_valid_flow_mask.shape[0]}"
-    )
-
-    assert gt_point_classes.shape[0] == is_valid_flow_mask.shape[0], (
-        f"Point classes and is_valid_flow_mask shape mismatch: "
-        f"{gt_point_classes.shape[0]} vs {is_valid_flow_mask.shape[0]}"
-    )
-
-
 def _validate_dataloader(
-    dataloader,
-    pc_size: int,
-    is_valid_entries: int,
+    dataloader: AbstractDataset,
+    full_pc_size: int,
+    masked_pc_size: int,
     expected_len: int = 1,
 ):
     assert len(dataloader) == expected_len, f"Expected {expected_len} scene, got {len(dataloader)}"
@@ -205,10 +108,18 @@ def _validate_dataloader(
 
     num_iteration_entries = 0
     for entry in dataloader:
-        assert isinstance(entry, tuple), f"Expected tuple, got {type(entry)}"
-        assert len(entry) == 2, f"Expected tuple of length 2, got {len(entry)}"
-        query, gt = entry
-        _validate_dataloader_elements(query, gt, pc_size, is_valid_entries)
+        assert isinstance(entry, list), f"Expected list, got {type(entry)}"
+        assert len(entry) == 2, f"Expected list of length 2, got {len(entry)}"
+        item_t1, _ = entry
+
+        assert (
+            full_pc_size == item_t1.pc.full_pc.shape[0]
+        ), f"Expected full pc to be of size {full_pc_size}, got {item_t1.pc.full_pc.shape[0]}"
+
+        assert (
+            masked_pc_size == item_t1.pc.pc.shape[0]
+        ), f"Expected masked pc to be of size {masked_pc_size}, got {item_t1.pc.pc.shape[0]}"
+
         num_iteration_entries += 1
 
     # Check that we actually iterated over the dataset.
@@ -217,8 +128,8 @@ def _validate_dataloader(
     ), f"Expected {expected_len} iteration, got {num_iteration_entries}"
 
 
-def test_waymo_dataset(waymo_dataset_gt):
-    _validate_dataloader(waymo_dataset_gt, 124364, 124364)
+# def test_waymo_dataset(waymo_dataset_gt):
+#     _validate_dataloader(waymo_dataset_gt, 124364, 124364)
 
 
 def test_argo_dataset_gt_with_ground(argo_dataset_gt_with_ground):