Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Undo fallback in abstract_dynamics.py #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions waymax/dynamics/abstract_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def apply_trajectory_update_to_state(
is_controlled: jax.Array,
timestep: int,
allow_object_injection: bool = False,
use_fallback: bool = False,
) -> datatypes.Trajectory:
"""Applies a TrajectoryUpdate to the sim trajectory at the next timestep.

Expand All @@ -150,7 +151,7 @@ def apply_trajectory_update_to_state(

For objects not in is_controlled, reference_trajectory is used.
For objects in is_controlled, but not valid in trajectory_update, fall back to
constant speed behaviour.
constant speed behaviour if the use_fallback flag is on.

Args:
trajectory_update: Updated trajectory fields for all objects after the
Expand All @@ -168,6 +169,8 @@ def apply_trajectory_update_to_state(
allow_object_injection: Whether to allow new objects to enter the scene. If
this is set to False, all objects that are not valid at the current
timestep will not be valid at the next timestep and visa versa.
use_fallback: Whether to fall back to constant speed if a controlled agent
is given an invalid action. Otherwise, the agent will be invalidated.

Returns:
Updated trajectory given update from a dynamics model at `timestep` + 1.
Expand Down Expand Up @@ -198,19 +201,25 @@ def apply_trajectory_update_to_state(
# Some fields such as the length, width and height of objects are set to be
# the same for every timestep during data loading and so we don't update these
# from the current trajectory.
# TODO: Update z using the (x, y) coordinates of the vehicle.
# Update z using the (x, y) coordinates of the vehicle.
replacement_dict = {}
for field in CONTROLLABLE_FIELDS:
# Use fallback trajectory if user doesn't not provide valid action.
new_value = jnp.where(
trajectory_update.valid,
trajectory_update[field],
fallback_trajectory[field],
)
# Only update for is_controlled objects from users.
replacement_dict[field] = jnp.where(
is_controlled, new_value, default_next_traj[field]
)
if use_fallback:
# Use fallback trajectory if user doesn't not provide valid action.
new_value = jnp.where(
trajectory_update.valid,
trajectory_update[field],
fallback_trajectory[field],
)
# Only update for is_controlled objects from users.
replacement_dict[field] = jnp.where(
is_controlled, new_value, default_next_traj[field]
)
else:
new_value = jnp.where(
is_controlled, trajectory_update[field], default_next_traj[field]
)
replacement_dict[field] = new_value

exist_and_controlled = is_controlled & current_traj.valid
# For exist_and_controlled objects, valid flags should remain the same as
Expand Down
Loading