Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: array IndexError when center_agent_idx is not set as sdc #50

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

SS47816
Copy link

@SS47816 SS47816 commented Feb 1, 2024

Fixed the behavior of the function plot_simulator_state() when the parameter center_agent_idx in viz_config is not -1 # sdc.

Original Issue

When the parameter center_agent_idx in viz_config is set to a user-specified index, the following issue will arise:

File [~/anaconda3/envs/waymax/lib/python3.10/site-packages/waymax/visualization/viz.py:287], in plot_simulator_state(state, use_log_traj, viz_config, batch_idx, highlight_obj)
    285 else:
    286   xy = current_xy[viz_config.center_agent_idx]
--> 287 origin_x, origin_y = xy[0, :2]
    288 ax.axis((
    289     origin_x - viz_config.back_x,
    290     origin_x + viz_config.front_x,
    291     origin_y - viz_config.back_y,
    292     origin_y + viz_config.front_y,
    293 ))
    295 return utils.img_from_fig(fig)

File [~/anaconda3/envs/waymax/lib/python3.10/site-packages/jax/_src/array.py:314], in ArrayImpl.__getitem__(self, idx)
    312   num_idx = sum(e is not None and e is not Ellipsis for e in idx)
    313   if num_idx > self.ndim:
--> 314     raise IndexError(
    315         f"Too many indices for array: array has ndim of {self.ndim}, but "
...
    316         f"was indexed with {num_idx} non-None[/Ellipsis](https://file+.vscode-resource.vscode-cdn.net/Ellipsis) indices.")
    318 if isinstance(self.sharding, PmapSharding):
    319   if not isinstance(idx, tuple):

IndexError: Too many indices for array: array has ndim of 1, but was indexed with 2 non-None[/Ellipsis](https://file+.vscode-resource.vscode-cdn.net/Ellipsis) indices.

This issue occurred because the shape of the xy = current_xy[state.object_metadata.is_sdc] is [1, 2] (2-dimension) whereas the shape of the xy = current_xy[viz_config.center_agent_idx] is [2] (1-dimension).

Fix

The following code in visualization/viz.py, from the line 280 to 294:

  # 3. Gets np img, centered on selected agent's current location.
  # [A, 2]
  current_xy = traj.xy[:, state.timestep, :]
  if viz_config.center_agent_idx == -1:
    xy = current_xy[state.object_metadata.is_sdc]
  else:
    xy = current_xy[viz_config.center_agent_idx]
  origin_x, origin_y = xy[0, :2]
  ax.axis((
      origin_x - viz_config.back_x,
      origin_x + viz_config.front_x,
      origin_y - viz_config.back_y,
      origin_y + viz_config.front_y,
  ))

has been changed to:

  # 3. Gets np img, centered on selected agent's current location.
  # [A, 2]
  current_xy = traj.xy[:, state.timestep, :]
  if viz_config.center_agent_idx == -1:
    xy = current_xy[state.object_metadata.is_sdc]
    origin_x, origin_y = xy[0, :2]
  else:
    xy = current_xy[viz_config.center_agent_idx]
    origin_x, origin_y = xy[:2]
  ax.axis((
      origin_x - viz_config.back_x,
      origin_x + viz_config.front_x,
      origin_y - viz_config.back_y,
      origin_y + viz_config.front_y,
  ))

so that the shapes of the origin_x, origin_y, and xy in the if-else statement are now properly aligned.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant