From d1270b936c987571dcdb47b3aa95c3e190d9ee69 Mon Sep 17 00:00:00 2001 From: Boris Ivanovic Date: Wed, 6 Jul 2022 22:59:52 -0700 Subject: [PATCH] Partially fixes #1 and adds scene-centric functionality, fixing #2. --- README.md | 1 - examples/batch_example.py | 2 +- examples/scene_batch_example.py | 51 +++ setup.cfg | 2 +- src/trajdata/augmentation/augmentation.py | 7 +- src/trajdata/augmentation/noise_histories.py | 9 +- src/trajdata/data_structures/batch.py | 67 +++- src/trajdata/data_structures/batch_element.py | 266 ++++++++++++- src/trajdata/data_structures/collation.py | 374 +++++++++++++++++- src/trajdata/data_structures/scene.py | 19 +- src/trajdata/dataset.py | 26 +- .../eth_ucy_peds/eupeds_dataset.py | 8 +- .../dataset_specific/nusc/nusc_dataset.py | 12 +- .../dataset_specific/scene_records.py | 2 + src/trajdata/simulation/sim_df_cache.py | 2 +- src/trajdata/simulation/sim_scene.py | 3 +- src/trajdata/utils/arr_utils.py | 71 ++++ src/trajdata/visualization/vis.py | 117 +++++- 18 files changed, 978 insertions(+), 61 deletions(-) create mode 100644 examples/scene_batch_example.py diff --git a/README.md b/README.md index 24ae05a..a5cac3e 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,6 @@ for t in range(1, sim_scene.scene_info.length_timesteps): `examples/sim_example.py` contains a more comprehensive example which initializes a simulation from a scene in the nuScenes mini dataset, steps through it by replaying agents' GT motions, and computes metrics based on scene statistics (e.g., displacement error from the original GT data, velocity/acceleration/jerk histograms). ## TODO -- Merge in upstream scene batch pull request. - Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training. - Add more examples to the README. - Finish README section about how to add a new dataset. diff --git a/examples/batch_example.py b/examples/batch_example.py index b7aad80..8a95761 100644 --- a/examples/batch_example.py +++ b/examples/batch_example.py @@ -19,7 +19,7 @@ def main(): future_sec=(4.8, 4.8), only_types=[AgentType.VEHICLE], agent_interaction_distances=defaultdict(lambda: 30.0), - incl_robot_future=True, + incl_robot_future=False, incl_map=True, map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, augmentations=[noise_hists], diff --git a/examples/scene_batch_example.py b/examples/scene_batch_example.py new file mode 100644 index 0000000..075cca9 --- /dev/null +++ b/examples/scene_batch_example.py @@ -0,0 +1,51 @@ +from collections import defaultdict + +from torch.utils.data import DataLoader +from tqdm import tqdm + +from trajdata import AgentBatch, AgentType, UnifiedDataset +from trajdata.augmentation import NoiseHistories +from trajdata.visualization.vis import plot_scene_batch + + +def main(): + noise_hists = NoiseHistories() + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_train"], + centric="scene", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_types=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_map=True, + map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, + augmentations=[noise_hists], + max_agent_num=20, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + print(f"# Data Samples: {len(dataset):,}") + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=4, + persistent_workers=True, + ) + + batch: AgentBatch + for batch in tqdm(dataloader): + plot_scene_batch(batch, batch_idx=0) + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index d4515ee..8766e68 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = trajdata -version = 1.0.4 +version = 1.0.5 author = Boris Ivanovic author_email = bivanovic@nvidia.com description = A unified interface to many trajectory forecasting datasets. diff --git a/src/trajdata/augmentation/augmentation.py b/src/trajdata/augmentation/augmentation.py index 94163ed..6e64c3a 100644 --- a/src/trajdata/augmentation/augmentation.py +++ b/src/trajdata/augmentation/augmentation.py @@ -1,6 +1,6 @@ import pandas as pd -from trajdata.data_structures.batch import AgentBatch +from trajdata.data_structures.batch import AgentBatch, SceneBatch class Augmentation: @@ -14,5 +14,8 @@ def apply(self, scene_data_df: pd.DataFrame) -> None: class BatchAugmentation(Augmentation): - def apply(self, agent_batch: AgentBatch) -> None: + def apply_agent(self, agent_batch: AgentBatch) -> None: + raise NotImplementedError() + + def apply_scene(self, scene_batch: SceneBatch) -> None: raise NotImplementedError() diff --git a/src/trajdata/augmentation/noise_histories.py b/src/trajdata/augmentation/noise_histories.py index e48dcf7..64be56e 100644 --- a/src/trajdata/augmentation/noise_histories.py +++ b/src/trajdata/augmentation/noise_histories.py @@ -1,7 +1,7 @@ import torch from trajdata.augmentation.augmentation import BatchAugmentation -from trajdata.data_structures.batch import AgentBatch +from trajdata.data_structures.batch import AgentBatch, SceneBatch class NoiseHistories(BatchAugmentation): @@ -9,10 +9,15 @@ def __init__(self, mean: float = 0.0, stddev: float = 0.1) -> None: self.mean = mean self.stddev = stddev - def apply(self, agent_batch: AgentBatch) -> None: + def apply_agent(self, agent_batch: AgentBatch) -> None: agent_batch.agent_hist[..., :-1, :] += torch.normal( self.mean, self.stddev, size=agent_batch.agent_hist[..., :-1, :].shape ) agent_batch.neigh_hist[..., :-1, :] += torch.normal( self.mean, self.stddev, size=agent_batch.neigh_hist[..., :-1, :].shape ) + + def apply_scene(self, scene_batch: SceneBatch) -> None: + scene_batch.agent_hist[..., :-1, :] += torch.normal( + self.mean, self.stddev, size=scene_batch.agent_hist[..., :-1, :].shape + ) diff --git a/src/trajdata/data_structures/batch.py b/src/trajdata/data_structures/batch.py index a7f71f4..17df5fe 100644 --- a/src/trajdata/data_structures/batch.py +++ b/src/trajdata/data_structures/batch.py @@ -37,6 +37,7 @@ class AgentBatch: maps_resolution: Optional[Tensor] rasters_from_world_tf: Optional[Tensor] agents_from_world_tf: Tensor + scene_ids: Optional[List] def to(self, device) -> None: excl_vals = { @@ -50,6 +51,7 @@ def to(self, device) -> None: "neigh_types", "num_neigh", "robot_fut_len", + "scene_ids", } for val in vars(self).keys(): tensor_val = getattr(self, val) @@ -96,7 +98,70 @@ def for_agent_type(self, agent_type: AgentType) -> AgentBatch: if self.rasters_from_world_tf is not None else None, agents_from_world_tf=self.agents_from_world_tf[match_type], + scene_ids=[ + scene_id + for idx, scene_id in enumerate(self.scene_ids) + if match_type[idx] + ], ) -SceneBatch = namedtuple("SceneBatch", "") +@dataclass +class SceneBatch: + data_idx: Tensor + dt: Tensor + num_agents: Tensor + agent_type: Tensor + centered_agent_state: Tensor + agent_hist: Tensor + agent_hist_extent: Tensor + agent_hist_len: Tensor + agent_fut: Tensor + agent_fut_extent: Tensor + agent_fut_len: Tensor + robot_fut: Optional[Tensor] + robot_fut_len: Optional[Tensor] + maps: Optional[Tensor] + maps_resolution: Optional[Tensor] + rasters_from_world_tf: Optional[Tensor] + centered_agent_from_world_tf: Tensor + centered_world_from_agent_tf: Tensor + + def to(self, device) -> None: + for val in vars(self).keys(): + tensor_val = getattr(self, val) + if tensor_val is not None: + setattr(self, val, tensor_val.to(device)) + + def agent_types(self) -> List[AgentType]: + unique_types: Tensor = torch.unique(self.agent_type) + return [AgentType(unique_type.item()) for unique_type in unique_types] + + def for_agent_type(self, agent_type: AgentType) -> AgentBatch: + match_type = self.agent_type == agent_type + return SceneBatch( + data_idx=self.data_idx[match_type], + dt=self.dt[match_type], + num_agents=self.num_agents[match_type], + agent_type=self.agent_type[match_type], + centered_agent_state=self.centered_agent_state[match_type], + agent_hist=self.agent_hist[match_type], + agent_hist_extent=self.agent_hist_extent[match_type], + agent_hist_len=self.agent_hist_len[match_type], + agent_fut=self.agent_fut[match_type], + agent_fut_extent=self.agent_fut_extent[match_type], + agent_fut_len=self.agent_fut_len[match_type], + robot_fut=self.robot_fut[match_type] + if self.robot_fut is not None + else None, + robot_fut_len=self.robot_fut_len[match_type], + maps=self.maps[match_type] if self.maps is not None else None, + maps_resolution=self.maps_resolution[match_type] + if self.maps_resolution is not None + else None, + rasters_from_world_tf=self.rasters_from_world_tf[match_type] + if self.rasters_from_world_tf is not None + else None, + centered_agent_from_world_tf=self.centered_agent_from_world_tf[match_type], + centered_world_from_agent_tf=self.centered_world_from_agent_tf[match_type], + ) diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index 7872ed9..6013087 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -5,7 +5,7 @@ import numpy as np from trajdata.caching import SceneCache -from trajdata.data_structures.agent import AgentMetadata, AgentType, FixedExtent +from trajdata.data_structures.agent import Agent, AgentMetadata, AgentType, FixedExtent from trajdata.data_structures.map_patch import MapPatch from trajdata.data_structures.scene import SceneTime, SceneTimeAgent @@ -28,6 +28,7 @@ def __init__( incl_map: bool = False, map_params: Optional[Dict[str, int]] = None, standardize_data: bool = False, + standardize_derivatives: bool = False, ) -> None: self.cache: SceneCache = cache self.data_index: int = data_index @@ -57,11 +58,16 @@ def __init__( ) self.agent_from_world_tf: np.ndarray = np.linalg.inv(world_from_agent_tf) + offset = self.curr_agent_state_np + if not standardize_derivatives: + offset[2:6] = 0.0 + cache.transform_data( - shift_mean_to=self.curr_agent_state_np, + shift_mean_to=offset, rotate_by=agent_heading, sincos_heading=True, ) + else: self.agent_from_world_tf: np.ndarray = np.eye(3) @@ -107,6 +113,7 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: ### ROBOT DATA ### self.robot_future_np: Optional[np.ndarray] = None + if incl_robot_future: self.robot_future_np: np.ndarray = self.get_robot_current_and_future( scene_time_agent.robot, future_sec @@ -122,6 +129,8 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: if incl_map: self.map_patch = self.get_agent_map_patch(map_params) + self.scene_id = scene_time_agent.scene.name + def get_agent_history( self, agent_info: AgentMetadata, @@ -142,7 +151,6 @@ def get_agent_future( ) return agent_future_np, agent_extent_future_np - # @profile def get_neighbor_history( self, scene_time: SceneTimeAgent, @@ -154,7 +162,6 @@ def get_neighbor_history( # which would have a distance of 0 to itself). agent_distances: np.ndarray = scene_time.get_agent_distances_to(agent_info) agent_idx: int = scene_time.agents.index(agent_info) - neighbor_types: np.ndarray = np.array([a.type.value for a in scene_time.agents]) nearby_mask: np.ndarray = agent_distances <= distance_limit( neighbor_types, agent_info.type @@ -284,8 +291,253 @@ class SceneBatchElement: def __init__( self, + cache: SceneCache, + data_index: int, scene_time: SceneTime, - history_sec_at_most: float, - future_sec_at_most: float, + history_sec: Tuple[Optional[float], Optional[float]], + future_sec: Tuple[Optional[float], Optional[float]], + agent_interaction_distances: Dict[ + Tuple[AgentType, AgentType], float + ] = defaultdict(lambda: np.inf), + incl_robot_future: bool = False, + incl_map: bool = False, + map_params: Optional[Dict[str, int]] = None, + standardize_data: bool = False, + standardize_derivatives: bool = False, + max_agent_num: Optional[int] = None, ) -> None: - self.history_sec_at_most = history_sec_at_most + self.cache: SceneCache = cache + self.data_index = data_index + self.dt: float = scene_time.scene.dt + self.scene_ts: int = scene_time.ts + + if max_agent_num is not None: + scene_time.agents = scene_time.agents[:max_agent_num] + + self.agents: List[AgentMetadata] = scene_time.agents + + robot = [agent for agent in self.agents if agent.name == "ego"] + if len(robot) > 0: + self.centered_agent = robot[0] + else: + self.centered_agent = self.agents[0] + + self.centered_agent_state_np: np.ndarray = cache.get_state( + self.centered_agent.name, self.scene_ts + ) + self.standardize_data = standardize_data + + if self.standardize_data: + agent_pos: np.ndarray = self.centered_agent_state_np[:2] + agent_heading: float = self.centered_agent_state_np[-1] + + cos_agent, sin_agent = np.cos(agent_heading), np.sin(agent_heading) + self.centered_world_from_agent_tf: np.ndarray = np.array( + [ + [cos_agent, -sin_agent, agent_pos[0]], + [sin_agent, cos_agent, agent_pos[1]], + [0.0, 0.0, 1.0], + ] + ) + self.centered_agent_from_world_tf: np.ndarray = np.linalg.inv( + self.centered_world_from_agent_tf + ) + + offset = self.centered_agent_state_np + if not standardize_derivatives: + offset[2:6] = 0.0 + + cache.transform_data( + shift_mean_to=offset, + rotate_by=agent_heading, + sincos_heading=True, + ) + else: + self.agent_from_world_tf: np.ndarray = np.eye(3) + + ### NEIGHBOR-SPECIFIC DATA ### + def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: + return np.array( + [ + agent_interaction_distances[(agent_type, target_type)] + for agent_type in agent_types + ] + ) + + nearby_agents, self.agent_types_np = self.get_nearby_agents( + scene_time, self.centered_agent, distance_limit + ) + + self.num_agents = len(nearby_agents) + ( + self.agent_histories, + self.agent_history_extents, + self.agent_history_lens_np, + ) = self.get_agents_history(history_sec, nearby_agents) + ( + self.agent_futures, + self.agent_future_extents, + self.agent_future_lens_np, + ) = self.get_agents_future(future_sec, nearby_agents) + + ### MAP ### + self.map_patches: Optional[MapPatch] = None + if incl_map: + self.map_patches = self.get_agents_map_patch( + map_params, self.agent_histories + ) + + ### ROBOT DATA ### + self.robot_future_np: Optional[np.ndarray] = None + + if incl_robot_future: + self.robot_future_np: np.ndarray = self.get_robot_current_and_future( + self.centered_agent, future_sec + ) + + # -1 because this is meant to hold the number of future steps + # (whereas the above returns the current + future, yielding + # one more timestep). + self.robot_future_len: int = self.robot_future_np.shape[0] - 1 + + def get_nearby_agents( + self, + scene_time: SceneTime, + agent: AgentMetadata, + distance_limit: Callable[[np.ndarray, int], np.ndarray], + ) -> Tuple[List[AgentMetadata], np.ndarray]: + agent_distances: np.ndarray = scene_time.get_agent_distances_to(agent) + + agents_types: np.ndarray = np.array([a.type.value for a in scene_time.agents]) + nearby_mask: np.ndarray = agent_distances <= distance_limit( + agents_types, agent.type + ) + # sort the agents based on their distance to the centered agent + idx = np.argsort(agent_distances) + num_qualified = nearby_mask.sum() + nearby_agents: List[AgentMetadata] = [ + scene_time.agents[idx[i]] for i in range(num_qualified) + ] + agents_types_np = agents_types[idx[:num_qualified]] + return nearby_agents, agents_types_np + + def get_agents_history( + self, + history_sec: Tuple[Optional[float], Optional[float]], + nearby_agents: List[AgentMetadata], + ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + # The indices of the returned ndarray match the scene_time agents list (including the index of the central agent, + # which would have a distance of 0 to itself). + ( + agent_histories, + agent_history_extents, + agent_history_lens_np, + ) = self.cache.get_agents_history(self.scene_ts, nearby_agents, history_sec) + + return ( + agent_histories, + agent_history_extents, + agent_history_lens_np, + ) + + def get_agents_future( + self, + future_sec: Tuple[Optional[float], Optional[float]], + nearby_agents: List[AgentMetadata], + ) -> Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]: + + ( + agent_futures, + agent_future_extents, + agent_future_lens_np, + ) = self.cache.get_agents_future(self.scene_ts, nearby_agents, future_sec) + + return ( + agent_futures, + agent_future_extents, + agent_future_lens_np, + ) + + def get_agents_map_patch( + self, patch_params: Dict[str, int], agent_histories: List[np.ndarray] + ) -> List[MapPatch]: + world_x, world_y = self.centered_agent_state_np[:2] + heading = self.centered_agent_state_np[-1] + desired_patch_size: int = patch_params["map_size_px"] + resolution: int = patch_params["px_per_m"] + offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) + return_rgb: bool = patch_params.get("return_rgb", True) + + if len(self.cache.heading_cols) == 1: + heading_idx = self.cache.heading_cols[0] + sincos = False + else: + heading_sin_idx, heading_cos_idx = self.cache.heading_cols + sincos = True + x_idx, y_idx = self.cache.pos_cols + + map_patches = list() + for agent_his in agent_histories: + if self.standardize_data: + if sincos: + agent_heading = ( + np.arctan2( + agent_his[-1, heading_sin_idx], + agent_his[-1, heading_cos_idx], + ) + + heading + ) + else: + agent_heading = agent_his[-1, heading_idx] + heading + + patch_data, raster_from_world_tf = self.cache.load_map_patch( + world_x + agent_his[-1, x_idx], + world_y + agent_his[-1, y_idx], + desired_patch_size, + resolution, + offset_xy, + agent_heading, + return_rgb, + rot_pad_factor=sqrt(2), + ) + + else: + agent_heading = 0.0 + patch_data, raster_from_world_tf = self.cache.load_map_patch( + agent_his[-1, x_idx], + agent_his[-1, y_idx], + desired_patch_size, + resolution, + offset_xy, + agent_heading, + return_rgb, + ) + + map_patches.append( + MapPatch( + data=patch_data, + rot_angle=agent_heading, + crop_size=desired_patch_size, + resolution=resolution, + raster_from_world_tf=raster_from_world_tf, + ) + ) + + return map_patches + + def get_robot_current_and_future( + self, + robot_info: AgentMetadata, + future_sec: Tuple[Optional[float], Optional[float]], + ) -> np.ndarray: + robot_curr_np: np.ndarray = self.cache.get_state(robot_info.name, self.scene_ts) + # robot_fut_extents_np, + ( + robot_fut_np, + _, + ) = self.cache.get_agent_future(robot_info, self.scene_ts, future_sec) + + robot_curr_and_fut_np: np.ndarray = np.concatenate( + (robot_curr_np[np.newaxis, :], robot_fut_np), axis=0 + ) + return robot_curr_and_fut_np diff --git a/src/trajdata/data_structures/collation.py b/src/trajdata/data_structures/collation.py index 2ad03a2..895cfb9 100644 --- a/src/trajdata/data_structures/collation.py +++ b/src/trajdata/data_structures/collation.py @@ -15,9 +15,9 @@ from trajdata.utils import arr_utils -def map_collate_fn( +def map_collate_fn_agent( batch_elems: List[AgentBatchElement], -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: +): if batch_elems[0].map_patch is None: return None, None, None @@ -38,7 +38,6 @@ def map_collate_fn( [batch_elem.map_patch.resolution for batch_elem in batch_elems], dtype=torch.float, ) - rasters_from_world_tf: Tensor = torch.as_tensor( np.stack( [batch_elem.map_patch.raster_from_world_tf for batch_elem in batch_elems] @@ -65,23 +64,118 @@ def map_collate_fn( rasters_from_world_tf, ) - return patch_data, resolution, rasters_from_world_tf + rot_crop_patches = patch_data + else: + + rot_crop_patches: Tensor = center_crop( + rotate(patch_data, torch.rad2deg(rot_angles)), (patch_size, patch_size) + ) + rasters_from_world_tf = torch.bmm( + arr_utils.transform_matrices( + -rot_angles, + torch.tensor([[patch_size // 2, patch_size // 2]]).expand( + (rot_angles.shape[0], -1) + ), + ), + rasters_from_world_tf, + ) + + return ( + rot_crop_patches, + resolution, + rasters_from_world_tf, + ) - rot_crop_patches: Tensor = center_crop( - rotate(patch_data, torch.rad2deg(rot_angles)), (patch_size, patch_size) + +def map_collate_fn_scene( + batch_elems: List[SceneBatchElement], + max_agent_num: Optional[int] = None, + pad_value: Any = np.nan, +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + + if batch_elems[0].map_patches is None: + return None, None, None + + patch_size: int = batch_elems[0].map_patches[0].crop_size + assert all( + batch_elem.map_patches[0].crop_size == patch_size for batch_elem in batch_elems ) - rasters_from_world_tf = torch.bmm( - arr_utils.transform_matrices( - -rot_angles, - torch.tensor([[patch_size // 2, patch_size // 2]]).expand( - (rot_angles.shape[0], -1) + num_agents: List[int] = list() + agents_rasters_from_world_tfs: List[np.ndarray] = list() + agents_patches: List[np.ndarray] = list() + agents_rot_angles_list: List[float] = list() + agents_res_list: List[float] = list() + + for elem in batch_elems: + num_agents.append(min(elem.num_agents, max_agent_num)) + agents_rasters_from_world_tfs += [ + x.raster_from_world_tf for x in elem.map_patches[:max_agent_num] + ] + agents_patches += [x.data for x in elem.map_patches[:max_agent_num]] + agents_rot_angles_list += [ + x.rot_angle for x in elem.map_patches[:max_agent_num] + ] + agents_res_list += [x.resolution for x in elem.map_patches[:max_agent_num]] + + patch_data: Tensor = torch.as_tensor(np.stack(agents_patches), dtype=torch.float) + agents_rot_angles: Tensor = torch.as_tensor( + np.stack(agents_rot_angles_list), dtype=torch.float + ) + agents_rasters_from_world_tf: Tensor = torch.as_tensor( + np.stack(agents_rasters_from_world_tfs), dtype=torch.float + ) + agents_resolution: Tensor = torch.as_tensor( + np.stack(agents_res_list), dtype=torch.int + ) + + if torch.count_nonzero(agents_rot_angles) == 0: + agents_rasters_from_world_tf = torch.bmm( + torch.tensor( + [ + [ + [1.0, 0.0, patch_size // 2], + [0.0, 1.0, patch_size // 2], + [0.0, 0.0, 1.0], + ] + ], + dtype=agents_rasters_from_world_tf.dtype, + device=agents_rasters_from_world_tf.device, + ).expand((agents_rasters_from_world_tf.shape[0], -1, -1)), + agents_rasters_from_world_tf, + ) + + rot_crop_patches = patch_data + else: + agents_rasters_from_world_tf = torch.bmm( + arr_utils.transform_matrices( + -agents_rot_angles, + torch.tensor([[patch_size // 2, patch_size // 2]]).expand( + (agents_rot_angles.shape[0], -1) + ), ), - ), - rasters_from_world_tf, + agents_rasters_from_world_tf, + ) + rot_crop_patches = center_crop( + rotate(patch_data, torch.rad2deg(agents_rot_angles)), + (patch_size, patch_size), + ) + + rot_crop_patches = split_pad_crop( + rot_crop_patches, num_agents, pad_value=pad_value, desired_size=max_agent_num ) - return rot_crop_patches, resolution, rasters_from_world_tf + agents_rasters_from_world_tf = split_pad_crop( + agents_rasters_from_world_tf, + num_agents, + pad_value=pad_value, + desired_size=max_agent_num, + ) + agents_resolution = split_pad_crop( + agents_resolution, num_agents, pad_value=0, desired_size=max_agent_num + ) + + return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf def agent_collate_fn( @@ -371,12 +465,19 @@ def agent_collate_fn( if robot_future else None ) - map_patches, maps_resolution, rasters_from_world_tf = map_collate_fn(batch_elems) + + ( + map_patches, + maps_resolution, + rasters_from_world_tf, + ) = map_collate_fn_agent(batch_elems) + agents_from_world_tf = torch.as_tensor( np.stack([batch_elem.agent_from_world_tf for batch_elem in batch_elems]), dtype=torch.float, ) + scene_ids = [batch_elem.scene_id for batch_elem in batch_elems] batch = AgentBatch( data_idx=data_index_t, dt=dt_t, @@ -403,21 +504,252 @@ def agent_collate_fn( maps_resolution=maps_resolution, rasters_from_world_tf=rasters_from_world_tf, agents_from_world_tf=agents_from_world_tf, + scene_ids=scene_ids, ) if batch_augments: for batch_aug in batch_augments: - batch_aug.apply(batch) + batch_aug.apply_agent(batch) if return_dict: return asdict(batch) - return batch -def scene_collate_fn(batch_elems: List[SceneBatchElement]) -> SceneBatch: - return SceneBatch( - nums=default_collate( - [batch_elem.history_sec_at_most for batch_elem in batch_elems] +def split_pad_crop( + batch_tensor, sizes, pad_value: float = 0.0, desired_size: Optional[int] = None +) -> Tensor: + """Split a batched tensor into different sizes and pad them to the same size + + Args: + batch_tensor: tensor in bach or split tensor list + sizes (torch.Tensor): sizes of each entry + pad_value (float, optional): padding value. Defaults to 0.0 + desired_size (int, optional): desired size. Defaults to None. + """ + + if isinstance(batch_tensor, Tensor): + x = torch.split(batch_tensor, sizes) + cat_fun = torch.cat + full_fun = torch.full + elif isinstance(batch_tensor, np.ndarray): + x = np.split(batch_tensor, sizes) + cat_fun = np.concatenate + full_fun = np.full + elif isinstance(batch_tensor, List): + # already splitted in list + x = batch_tensor + if isinstance(batch_tensor[0], Tensor): + cat_fun = torch.cat + full_fun = torch.full + elif isinstance(batch_tensor[0], np.ndarray): + cat_fun = np.concatenate + full_fun = np.full + else: + raise ValueError("wrong data type for batch tensor") + + x: Tensor = pad_sequence(x, batch_first=True, padding_value=pad_value) + if desired_size is not None: + if x.shape[1] >= desired_size: + x = x[:, :desired_size] + else: + bs, max_size = x.shape[:2] + x = cat_fun( + (x, full_fun([bs, desired_size - max_size, *x.shape[2:]], pad_value)), 1 + ) + + return x + + +def scene_collate_fn( + batch_elems: List[SceneBatchElement], + return_dict: bool, + batch_augments: Optional[List[BatchAugmentation]] = None, +) -> SceneBatch: + batch_size: int = len(batch_elems) + data_index_t: Tensor = torch.zeros((batch_size,), dtype=torch.int) + dt_t: Tensor = torch.zeros((batch_size,), dtype=torch.float) + + max_agent_num: int = max(elem.num_agents for elem in batch_elems) + + centered_agent_state: List[Tensor] = list() + agents_types: List[Tensor] = list() + agents_histories: List[Tensor] = list() + agents_history_extents: List[Tensor] = list() + agents_history_len: Tensor = torch.zeros( + (batch_size, max_agent_num), dtype=torch.long + ) + + agents_futures: List[Tensor] = list() + agents_future_extents: List[Tensor] = list() + agents_future_len: Tensor = torch.zeros( + (batch_size, max_agent_num), dtype=torch.long + ) + + num_agents: List[int] = [elem.num_agents for elem in batch_elems] + num_agents_t: Tensor = torch.as_tensor(num_agents, dtype=torch.long) + + max_history_len: int = max(elem.agent_history_lens_np.max() for elem in batch_elems) + max_future_len: int = max(elem.agent_future_lens_np.max() for elem in batch_elems) + + robot_future: List[Tensor] = list() + robot_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) + + for idx, elem in enumerate(batch_elems): + data_index_t[idx] = elem.data_index + dt_t[idx] = elem.dt + centered_agent_state.append(elem.centered_agent_state_np) + agents_types.append(elem.agent_types_np) + history_len_i = torch.tensor( + [rec.shape[0] for rec in elem.agent_histories[:max_agent_num]] + ) + future_len_i = torch.tensor( + [rec.shape[0] for rec in elem.agent_futures[:max_agent_num]] + ) + agents_history_len[idx, : elem.num_agents] = history_len_i + agents_future_len[idx, : elem.num_agents] = future_len_i + + # History + padded_agents_histories = pad_sequence( + [ + torch.as_tensor(nh, dtype=torch.float).flip(-2) + for nh in elem.agent_histories[:max_agent_num] + ], + batch_first=True, + padding_value=np.nan, + ).flip(-2) + padded_agents_history_extents = pad_sequence( + [ + torch.as_tensor(nh, dtype=torch.float).flip(-2) + for nh in elem.agent_history_extents[:max_agent_num] + ], + batch_first=True, + padding_value=np.nan, + ).flip(-2) + if padded_agents_histories.shape[-2] < max_history_len: + to_add = max_history_len - padded_agents_histories.shape[-2] + padded_agents_histories = F.pad( + padded_agents_histories, + pad=(0, 0, to_add, 0), + mode="constant", + value=np.nan, + ) + padded_agents_history_extents = F.pad( + padded_agents_history_extents, + pad=(0, 0, to_add, 0), + mode="constant", + value=np.nan, + ) + + agents_histories.append(padded_agents_histories) + agents_history_extents.append(padded_agents_history_extents) + + # Future + padded_agents_futures = pad_sequence( + [ + torch.as_tensor(nh, dtype=torch.float) + for nh in elem.agent_futures[:max_agent_num] + ], + batch_first=True, + padding_value=np.nan, ) + padded_agents_future_extents = pad_sequence( + [ + torch.as_tensor(nh, dtype=torch.float) + for nh in elem.agent_future_extents + ], + batch_first=True, + padding_value=np.nan, + ) + if padded_agents_futures.shape[-2] < max_future_len: + to_add = max_future_len - padded_agents_futures.shape[-2] + padded_agents_futures = F.pad( + padded_agents_futures, + pad=(0, 0, 0, to_add), + mode="constant", + value=np.nan, + ) + padded_agents_future_extents = F.pad( + padded_agents_future_extents, + pad=(0, 0, 0, to_add), + mode="constant", + value=np.nan, + ) + + agents_futures.append(padded_agents_futures) + agents_future_extents.append(padded_agents_future_extents) + + if elem.robot_future_np is not None: + robot_future.append( + torch.as_tensor(elem.robot_future_np, dtype=torch.float) + ) + robot_future_len[idx] = elem.robot_future_len + + agents_histories_t = split_pad_crop( + agents_histories, num_agents, np.nan, max_agent_num + ) + agents_history_extents_t = split_pad_crop( + agents_history_extents, num_agents, np.nan, max_agent_num ) + agents_futures_t = split_pad_crop(agents_futures, num_agents, np.nan, max_agent_num) + agents_future_extents_t = split_pad_crop( + agents_future_extents, num_agents, np.nan, max_agent_num + ) + + centered_agent_state_t = torch.tensor(np.stack(centered_agent_state)) + agents_types_t = torch.as_tensor(np.concatenate(agents_types)) + agents_types_t = split_pad_crop( + agents_types_t, num_agents, pad_value=-1, desired_size=max_agent_num + ) + + map_patches, maps_resolution, rasters_from_world_tf = map_collate_fn_scene( + batch_elems, max_agent_num + ) + centered_agent_from_world_tf = torch.as_tensor( + np.stack( + [batch_elem.centered_agent_from_world_tf for batch_elem in batch_elems] + ), + dtype=torch.float, + ) + centered_world_from_agent_tf = torch.as_tensor( + np.stack( + [batch_elem.centered_world_from_agent_tf for batch_elem in batch_elems] + ), + dtype=torch.float, + ) + + robot_future_t: Optional[Tensor] = ( + pad_sequence(robot_future, batch_first=True, padding_value=np.nan) + if robot_future + else None + ) + + batch = SceneBatch( + data_idx=data_index_t, + dt=dt_t, + num_agents=num_agents_t, + agent_type=agents_types_t, + centered_agent_state=centered_agent_state_t, + agent_hist=agents_histories_t, + agent_hist_extent=agents_history_extents_t, + agent_hist_len=agents_history_len, + agent_fut=agents_futures_t, + agent_fut_extent=agents_future_extents_t, + agent_fut_len=agents_future_len, + robot_fut=robot_future_t, + robot_fut_len=robot_future_len, + maps=map_patches, + maps_resolution=maps_resolution, + rasters_from_world_tf=rasters_from_world_tf, + centered_agent_from_world_tf=centered_agent_from_world_tf, + centered_world_from_agent_tf=centered_world_from_agent_tf, + ) + + if batch_augments: + for batch_aug in batch_augments: + batch_aug.apply_scene(batch) + + if return_dict: + return asdict(batch) + + return batch diff --git a/src/trajdata/data_structures/scene.py b/src/trajdata/data_structures/scene.py index 9f1303b..d753148 100644 --- a/src/trajdata/data_structures/scene.py +++ b/src/trajdata/data_structures/scene.py @@ -38,24 +38,15 @@ def from_cache( agents_present, no_types, only_types ) - data_df: pd.DataFrame = cache.load_all_agent_data(scene) - - agents: List[Agent] = list() - for agent_info in filtered_agents: - agents.append(Agent(agent_info, data_df.loc[agent_info.name])) - - return cls(scene, scene_ts, agents, cache) + return cls(scene, scene_ts, filtered_agents, cache) def get_agent_distances_to(self, agent: Agent) -> np.ndarray: - agent_pos = np.array( - [[agent.data.at[self.ts, "x"], agent.data.at[self.ts, "y"]]] + agent_pos: np.ndarray = self.cache.get_state(agent.name, self.ts)[:2] + nb_pos: np.ndarray = np.stack( + [self.cache.get_state(nb.name, self.ts)[:2] for nb in self.agents] ) - data_df: pd.DataFrame = self.cache.load_agent_xy_at_time(self.ts, self.scene) - - agent_ids = [a.name for a in self.agents] - curr_poses = data_df.loc[agent_ids, ["x", "y"]].values - return np.linalg.norm(curr_poses - agent_pos, axis=1) + return np.linalg.norm(nb_pos - agent_pos, axis=1) class SceneTimeAgent: diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py index 233e1d3..c6f906f 100644 --- a/src/trajdata/dataset.py +++ b/src/trajdata/dataset.py @@ -60,7 +60,9 @@ def __init__( only_types: Optional[List[AgentType]] = None, no_types: Optional[List[AgentType]] = None, standardize_data: bool = True, + standardize_derivatives: bool = True, augmentations: Optional[List[Augmentation]] = None, + max_agent_num: Optional[int] = None, data_dirs: Dict[str, str] = { # "nusc": "~/datasets/nuScenes", "eupeds_eth": "~/datasets/eth_ucy_peds", @@ -135,8 +137,10 @@ def __init__( self.only_types = None if only_types is None else set(only_types) self.no_types = None if no_types is None else set(no_types) self.standardize_data = standardize_data + self.standardize_derivatives = standardize_derivatives self.augmentations = augmentations self.verbose = verbose + self.max_agent_num = max_agent_num # Ensuring scene description queries are all lowercase if scene_description_contains is not None: @@ -388,7 +392,9 @@ def get_collate_fn(self, return_dict: bool = False) -> Callable: ) elif self.centric == "scene": collate_fn = partial( - scene_collate_fn, return_dict=return_dict, batch_augments=batch_augments + scene_collate_fn, + return_dict=return_dict, + batch_augments=batch_augments, ) return collate_fn @@ -576,7 +582,6 @@ def __len__(self) -> int: # @profile def __getitem__(self, idx: int) -> AgentBatchElement: scene_path, scene_index_elem = self._data_index[idx] - if self.centric == "scene": scene_info, _, scene_index_elems = UnifiedDataset._get_data_index_scene( (scene_path, None), @@ -618,8 +623,20 @@ def __getitem__(self, idx: int) -> AgentBatchElement: only_types=self.only_types, no_types=self.no_types, ) - - return SceneBatchElement(scene_time, self.history_sec, self.future_sec) + return SceneBatchElement( + scene_cache, + idx, + scene_time, + self.history_sec, + self.future_sec, + self.agent_interaction_distances, + self.incl_robot_future, + self.incl_map, + self.map_params, + self.standardize_data, + self.standardize_derivatives, + self.max_agent_num, + ) elif self.centric == "agent": scene_time_agent: SceneTimeAgent = SceneTimeAgent.from_cache( scene_info, @@ -642,4 +659,5 @@ def __getitem__(self, idx: int) -> AgentBatchElement: self.incl_map, self.map_params, self.standardize_data, + self.standardize_derivatives, ) diff --git a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py index 593db68..74182c6 100644 --- a/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py +++ b/src/trajdata/dataset_specific/eth_ucy_peds/eupeds_dataset.py @@ -179,7 +179,13 @@ def _get_matching_scenes_from_cache( scenes_list: List[Scene] = list() for scene_record in all_scenes_list: - scene_name, scene_location, scene_length, scene_split, data_idx = scene_record + ( + scene_name, + scene_location, + scene_length, + scene_split, + data_idx, + ) = scene_record if ( (scene_location in scene_tag or "loo" in scene_split) diff --git a/src/trajdata/dataset_specific/nusc/nusc_dataset.py b/src/trajdata/dataset_specific/nusc/nusc_dataset.py index 5f74831..1d85749 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_dataset.py +++ b/src/trajdata/dataset_specific/nusc/nusc_dataset.py @@ -95,7 +95,9 @@ def _get_matching_scenes_from_obj( # Saving all scene records for later caching. all_scenes_list.append( - NuscSceneRecord(scene_name, scene_location, scene_length, scene_desc, idx) + NuscSceneRecord( + scene_name, scene_location, scene_length, scene_desc, idx + ) ) if scene_location.split("-")[0] in scene_tag and scene_split in scene_tag: @@ -127,7 +129,13 @@ def _get_matching_scenes_from_cache( scenes_list: List[SceneMetadata] = list() for scene_record in all_scenes_list: - scene_name, scene_location, scene_length, scene_desc, data_idx = scene_record + ( + scene_name, + scene_location, + scene_length, + scene_desc, + data_idx, + ) = scene_record scene_split: str = self.metadata.scene_split_map[scene_name] if scene_location.split("-")[0] in scene_tag and scene_split in scene_tag: diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index d24205b..60d3abb 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -8,6 +8,7 @@ class EUPedsRecord(NamedTuple): split: str data_idx: int + class NuscSceneRecord(NamedTuple): name: str location: str @@ -15,6 +16,7 @@ class NuscSceneRecord(NamedTuple): desc: str data_idx: int + class LyftSceneRecord(NamedTuple): name: str length: str diff --git a/src/trajdata/simulation/sim_df_cache.py b/src/trajdata/simulation/sim_df_cache.py index d4eeaaa..99a0fc0 100644 --- a/src/trajdata/simulation/sim_df_cache.py +++ b/src/trajdata/simulation/sim_df_cache.py @@ -121,7 +121,7 @@ def append_state(self, xyh_dict: Dict[str, np.ndarray]) -> None: sim_step_df = pd.DataFrame(sim_dict) sim_step_df.set_index(["agent_id", "scene_ts"], inplace=True) - if self.scene_ts < self.scene.length_timesteps: + if self.scene_ts < self.scene.length_timesteps and self.scene_ts in self.persistent_data_df.index.get_level_values(1): self.persistent_data_df.drop(index=self.scene_ts, level=1, inplace=True) self.persistent_data_df = pd.concat([self.persistent_data_df, sim_step_df]) diff --git a/src/trajdata/simulation/sim_scene.py b/src/trajdata/simulation/sim_scene.py index 16b0c91..247a153 100644 --- a/src/trajdata/simulation/sim_scene.py +++ b/src/trajdata/simulation/sim_scene.py @@ -37,7 +37,6 @@ def __init__( self.init_scene_ts: int = init_timestep self.freeze_agents: bool = freeze_agents self.return_dict: bool = return_dict - self.scene_ts: int = self.init_scene_ts agents_present: List[AgentMetadata] = self.scene_info.agent_presence[ @@ -111,7 +110,6 @@ def get_obs( scene_time_agent = SceneTimeAgent( self.scene_info, self.scene_ts, self.agents, agent, self.cache ) - agent_data_list.append( AgentBatchElement( self.cache, @@ -124,6 +122,7 @@ def get_obs( incl_map=get_map and self.dataset.incl_map, map_params=self.dataset.map_params, standardize_data=self.dataset.standardize_data, + standardize_derivatives=self.dataset.standardize_derivatives, ) ) diff --git a/src/trajdata/utils/arr_utils.py b/src/trajdata/utils/arr_utils.py index 5c9e09d..9f3338c 100644 --- a/src/trajdata/utils/arr_utils.py +++ b/src/trajdata/utils/arr_utils.py @@ -79,6 +79,23 @@ def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: ) +def batch_nd_transform_points_np(points, Mat): + ndim = Mat.shape[-1] - 1 + batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2] + Mat = np.transpose(Mat, batch) + if points.ndim == Mat.ndim - 1: + return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + elif points.ndim == Mat.ndim: + return ( + (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) + + Mat[..., np.newaxis, -1:, :ndim] + ).squeeze(-2) + else: + raise Exception("wrong shape") + + def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: values_diff: np.ndarray = np.diff( values, axis=0, prepend=values[[0]] - (values[[1]] - values[[0]]) @@ -97,3 +114,57 @@ def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: values_diff[border_mask] = values_diff[border_mask + 1] return values_diff + + +def batch_proj(x, line): + # x:[batch,3], line:[batch,N,3] + line_length = line.shape[-2] + batch_dim = x.ndim - 1 + if isinstance(x, torch.Tensor): + delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( + *([1] * batch_dim), line_length, 1 + ) + dis = torch.linalg.norm(delta, axis=-1) + idx0 = torch.argmin(dis, dim=-1) + idx = idx0.view(*line.shape[:-2], 1, 1).repeat( + *([1] * (batch_dim + 1)), line.shape[-1] + ) + line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * torch.sin(line_min[..., None, 2]) + dy * torch.cos( + line_min[..., None, 2] + ) + delta_x = dx * torch.cos(line_min[..., None, 2]) + dy * torch.sin( + line_min[..., None, 2] + ) + + delta_psi = angle_wrap(x[..., 2] - line_min[..., 2]) + + return ( + delta_x, + delta_y, + torch.unsqueeze(delta_psi, dim=-1), + ) + elif isinstance(x, np.ndarray): + delta = line[..., 0:2] - np.repeat( + x[..., np.newaxis, 0:2], line_length, axis=-2 + ) + dis = np.linalg.norm(delta, axis=-1) + idx0 = np.argmin(dis, axis=-1) + idx = idx0.reshape(*line.shape[:-2], 1, 1).repeat(line.shape[-1], axis=-1) + line_min = np.squeeze(np.take_along_axis(line, idx, axis=-2), axis=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * np.sin(line_min[..., None, 2]) + dy * np.cos( + line_min[..., None, 2] + ) + delta_x = dx * np.cos(line_min[..., None, 2]) + dy * np.sin( + line_min[..., None, 2] + ) + delta_psi = angle_wrap(x[..., 2] - line_min[..., 2]) + return ( + delta_x, + delta_y, + np.expand_dims(delta_psi, axis=-1), + ) diff --git a/src/trajdata/visualization/vis.py b/src/trajdata/visualization/vis.py index 9d08f41..cea0abb 100644 --- a/src/trajdata/visualization/vis.py +++ b/src/trajdata/visualization/vis.py @@ -6,7 +6,7 @@ from torch import Tensor from trajdata.data_structures.agent import AgentType -from trajdata.data_structures.batch import AgentBatch +from trajdata.data_structures.batch import AgentBatch, SceneBatch from trajdata.data_structures.map import Map @@ -131,3 +131,118 @@ def plot_agent_batch( if close: plt.close() + + +def plot_scene_batch( + batch: SceneBatch, + batch_idx: int, + ax: Optional[Axes] = None, + show: bool = True, + close: bool = True, +) -> None: + if ax is None: + _, ax = plt.subplots() + + num_agents: int = batch.num_agents[batch_idx].item() + + history_xy: Tensor = batch.agent_hist[batch_idx].cpu() + center_xy: Tensor = batch.agent_hist[batch_idx, ..., -1, :2].cpu() + future_xy: Tensor = batch.agent_fut[batch_idx, ..., :2].cpu() + + if batch.maps is not None: + centered_agent_id: int = 0 + agent_from_world_tf: Tensor = batch.centered_agent_from_world_tf[ + batch_idx + ].cpu() + world_from_raster_tf: Tensor = torch.linalg.inv( + batch.rasters_from_world_tf[batch_idx, centered_agent_id].cpu() + ) + + agent_from_raster_tf: Tensor = agent_from_world_tf @ world_from_raster_tf + + patch_size: int = batch.maps[batch_idx, centered_agent_id].shape[-1] + + left_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ + 0 + ].item() + right_extent: float = ( + agent_from_raster_tf @ torch.tensor([patch_size, 0.0, 1.0]) + )[0].item() + bottom_extent: float = ( + agent_from_raster_tf @ torch.tensor([0.0, patch_size, 1.0]) + )[1].item() + top_extent: float = (agent_from_raster_tf @ torch.tensor([0.0, 0.0, 1.0]))[ + 1 + ].item() + + ax.imshow( + Map.to_img( + batch.maps[batch_idx, centered_agent_id].cpu(), + # [[0], [1], [2]] + # [[0, 1, 2], [3, 4], [5, 6]], + ), + extent=( + left_extent, + right_extent, + bottom_extent, + top_extent, + ), + alpha=0.3, + ) + + for agent_id in range(num_agents): + ax.plot( + history_xy[agent_id, ..., 0], + history_xy[agent_id, ..., 1], + c="orange", + ls="--", + label="Agent History" if agent_id == 0 else None, + ) + ax.quiver( + history_xy[agent_id, ..., 0], + history_xy[agent_id, ..., 1], + history_xy[agent_id, ..., -1], + history_xy[agent_id, ..., -2], + color="k", + ) + ax.plot( + future_xy[agent_id, ..., 0], + future_xy[agent_id, ..., 1], + c="violet", + label="Agent Future" if agent_id == 0 else None, + ) + ax.scatter( + center_xy[agent_id, 0], + center_xy[agent_id, 1], + s=20, + c="orangered", + label="Agent Current" if agent_id == 0 else None, + ) + + if batch.robot_fut is not None and batch.robot_fut.shape[1] > 0: + ax.plot( + batch.robot_fut[batch_idx, 1:, 0], + batch.robot_fut[batch_idx, 1:, 1], + label="Ego Future", + c="blue", + ) + ax.scatter( + batch.robot_fut[batch_idx, 0, 0], + batch.robot_fut[batch_idx, 0, 1], + s=20, + c="blue", + label="Ego Current", + ) + + ax.set_xlabel("x (m)") + ax.set_ylabel("y (m)") + + ax.grid(False) + ax.legend(loc="best", frameon=True) + ax.axis("equal") + + if show: + plt.show() + + if close: + plt.close()