Skip to content

Commit

Permalink
Adds support for drift in RayCaster and makes fields in `ContactSen…
Browse files Browse the repository at this point in the history
…sorData` optional (#201)

# Description

This MR adds support for 2D drift into the `RayCaster`. It also makes
certain attributes in the `ContactSensorData` optional since they are
not needed by default.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
  • Loading branch information
Mayankm96 authored Oct 24, 2023
1 parent c6af307 commit 514baa4
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 45 deletions.
2 changes: 1 addition & 1 deletion source/extensions/omni.isaac.orbit/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.13"
version = "0.9.14"

# Description
title = "ORBIT framework for Robot Learning"
Expand Down
18 changes: 18 additions & 0 deletions source/extensions/omni.isaac.orbit/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
Changelog
---------

0.9.14 (2023-10-21)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added 2-D drift (i.e. along x and y) to the :class:`omni.isaac.orbit.sensors.RayCaster` class.
* Added flags to the :class:`omni.isaac.orbit.sensors.ContactSensorCfg` to optionally obtain the
sensor origin and air time information. Since these are not required by default, they are
disabled by default.

Fixed
^^^^^

* Fixed the handling of contact sensor history buffer in the :class:`omni.isaac.orbit.sensors.ContactSensor` class.
Earlier, the buffer was not being updated correctly.


0.9.13 (2023-10-20)
~~~~~~~~~~~~~~~~~~~

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# SPDX-License-Identifier: BSD-3-Clause

# Ignore optional memory usage warning globally
# pyright: reportOptionalSubscript=false

from __future__ import annotations

Expand Down Expand Up @@ -137,11 +139,17 @@ def reset(self, env_ids: Sequence[int] | None = None):
if env_ids is None:
env_ids = slice(None)
# reset accumulative data buffers
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
self._data.net_forces_w[env_ids] = 0.0
# reset the data history
self._data.net_forces_w_history[env_ids] = 0.0
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids] = 0.0
# reset force matrix
if len(self.cfg.filter_prim_paths_expr) != 0:
self._data.force_matrix_w[env_ids] = 0.0
# reset the current air time
if self.cfg.track_air_time:
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
# Set all reset sensors to not outdated since their value won't be updated till next sim step.
self._is_outdated[env_ids] = False

Expand Down Expand Up @@ -202,15 +210,24 @@ def _initialize_impl(self):
f"\n\tResolved prim paths: {body_names_regex}"
)

# fill the data buffer
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
# prepare data buffers
self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.net_forces_w_history = torch.zeros(
self._num_envs, self.cfg.history_length + 1, self._num_bodies, 3, device=self._device
)
# optional buffers
# -- history of net forces
if self.cfg.history_length > 0:
self._data.net_forces_w_history = torch.zeros(
self._num_envs, self.cfg.history_length, self._num_bodies, 3, device=self._device
)
else:
self._data.net_forces_w_history = self._data.net_forces_w.unsqueeze(1)
# -- pose of sensor origins
if self.cfg.track_pose:
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
# -- air time between contacts
if self.cfg.track_air_time:
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
# force matrix: (num_sensors, num_bodies, num_shapes, num_filter_shapes, 3)
if len(self.cfg.filter_prim_paths_expr) != 0:
num_shapes = self.contact_physx_view.sensor_count // self._num_bodies
Expand All @@ -224,17 +241,16 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
# default to all sensors
if len(env_ids) == self._num_envs:
env_ids = slice(None)
# obtain the poses of the sensors:
# TODO decide if we really to track poses -- This is the body's CoM. Not contact location.
pose = self.body_physx_view.get_transforms()
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]

# obtain the contact forces
# TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B).
# This isn't the most efficient way to do this, but it's the easiest to implement.
net_forces_w = self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt)
self._data.net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids]
# update contact force history
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids, 1:] = self._data.net_forces_w_history[env_ids, :-1].clone()
self._data.net_forces_w_history[env_ids, 0] = self._data.net_forces_w[env_ids]

# obtain the contact force matrix
if len(self.cfg.filter_prim_paths_expr) != 0:
Expand All @@ -245,26 +261,25 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
force_matrix_w = self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt)
force_matrix_w = force_matrix_w.view(-1, self._num_bodies, num_shapes, num_filters, 3)
self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids]

# update contact force history
previous_net_forces_w = self._data.net_forces_w_history.clone()
self._data.net_forces_w_history[env_ids, 0, :, :] = self._data.net_forces_w[env_ids, :, :]
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids, 1:, :, :] = previous_net_forces_w[env_ids, :-1, :, :]

# contact state
# -- time elapsed since last update
# since this function is called every frame, we can use the difference to get the elapsed time
elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
# -- check contact state of bodies
is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > 1.0
is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact
# -- update ongoing timer for bodies air
self._data.current_air_time[env_ids] += elapsed_time.unsqueeze(-1)
# -- update time for the last time bodies were in contact
self._data.last_air_time[env_ids] = self._data.current_air_time[env_ids] * is_first_contact
# -- increment timers for bodies that are not in contact
self._data.current_air_time[env_ids] *= ~is_contact
# obtain the pose of the sensor origin
if self.cfg.track_pose:
pose = self.body_physx_view.get_transforms()
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]
# obtain the air time
if self.cfg.track_air_time:
# -- time elapsed since last update
# since this function is called every frame, we can use the difference to get the elapsed time
elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
# -- check contact state of bodies
is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > 1.0
is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact
# -- update ongoing timer for bodies air
self._data.current_air_time[env_ids] += elapsed_time.unsqueeze(-1)
# -- update time for the last time bodies were in contact
self._data.last_air_time[env_ids] = self._data.current_air_time[env_ids] * is_first_contact
# -- increment timers for bodies that are not in contact
self._data.current_air_time[env_ids] *= ~is_contact

def _debug_vis_impl(self):
# visualize the contacts
Expand All @@ -276,4 +291,10 @@ def _debug_vis_impl(self):
net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1)
marker_indices = torch.where(net_contact_force_w > 1.0, 0, 1)
# check if prim is visualized
self.contact_visualizer.visualize(self._data.pos_w.view(-1, 3), marker_indices=marker_indices.view(-1))
if self.cfg.track_pose:
frame_origins: torch.Tensor = self._data.pos_w
else:
pose = self.body_physx_view.get_transforms()
frame_origins = pose.view(-1, self._num_bodies, 7)[:, :, :3]
# visualize
self.contact_visualizer.visualize(frame_origins.view(-1, 3), marker_indices=marker_indices.view(-1))
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ class ContactSensorCfg(SensorBaseCfg):

class_type: type = ContactSensor

track_pose: bool = False
"""Whether to track the pose of the sensor's origin. Defaults to False."""

track_air_time: bool = False
"""Whether to track the air time of the bodies (time between contacts). Defaults to False."""

filter_prim_paths_expr: list[str] = list()
"""The list of primitive paths to filter contacts with.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,30 @@
class ContactSensorData:
"""Data container for the contact reporting sensor."""

pos_w: torch.Tensor = None
pos_w: torch.Tensor | None = None
"""Position of the sensor origin in world frame.
Shape is (N, 3), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
"""
quat_w: torch.Tensor = None

quat_w: torch.Tensor | None = None
"""Orientation of the sensor origin in quaternion ``(w, x, y, z)`` in world frame.
Shape is (N, 4), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
"""

net_forces_w: torch.Tensor = None
"""The net contact forces in world frame.
Shape is (N, B, 3), where ``N`` is the number of sensors and ``B`` is the number of bodies in each sensor.
"""

net_forces_w_history: torch.Tensor = None
"""The net contact forces in world frame.
Expand All @@ -38,22 +46,30 @@ class ContactSensorData:
In the history dimension, the first index is the most recent and the last index is the oldest.
"""

force_matrix_w: torch.Tensor = None
force_matrix_w: torch.Tensor | None = None
"""The contact forces filtered between the sensor bodies and filtered bodies in world frame.
Shape is (N, B, S, M, 3), where ``N`` is the number of sensors, ``B`` is number of bodies in each sensor,
``S`` is number of shapes per body and ``M`` is the number of filtered bodies.
If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this tensor will be empty.
Note:
If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
"""

last_air_time: torch.Tensor = None
last_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air before the last contact.
Shape is (N,), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
current_air_time: torch.Tensor = None

current_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air since the last contact.
Shape is (N,), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ def set_debug_vis(self, debug_vis: bool):
if self.ray_visualizer is not None:
self.ray_visualizer.set_visibility(debug_vis)

def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters
super().reset(env_ids)
# resolve None
if env_ids is None:
env_ids = slice(None)
# resample the drift
self.drift[env_ids].uniform_(*self.cfg.drift_range)

"""
Implementation.
"""
Expand Down Expand Up @@ -180,7 +189,8 @@ def _initialize_impl(self):
# repeat the rays for each sensor
self.ray_starts = self.ray_starts.repeat(self._view.count, 1, 1)
self.ray_directions = self.ray_directions.repeat(self._view.count, 1, 1)

# prepare drift
self.drift = torch.zeros(self._view.count, 3, device=self.device)
# fill the data buffer
self._data.pos_w = torch.zeros(self._view.count, 3, device=self._device)
self._data.quat_w = torch.zeros(self._view.count, 4, device=self._device)
Expand All @@ -190,6 +200,7 @@ def _update_buffers_impl(self, env_ids: Sequence[int]):
"""Fills the buffers of the sensor data."""
# obtain the poses of the sensors
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=False)
pos_w += self.drift[env_ids]
self._data.pos_w[env_ids] = pos_w
self._data.quat_w[env_ids] = quat_w

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,9 @@ class OffsetCfg:

max_distance: float = 100.0
"""Maximum distance (in meters) from the sensor to ray cast to. Defaults to 100.0."""

drift_range: tuple[float, float] = (0.0, 0.0)
"""The range of drift (in meters) to add to the ray starting positions (xyz). Defaults to (0.0, 0.0).
For floating base robots, this is useful for simulating drift in the robot's pose estimation.
"""
22 changes: 22 additions & 0 deletions source/extensions/omni.isaac.orbit/test/isaacsim/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ def test_array_slicing(self):
self.assertEqual(my_tensor[slice(None), 0, 0].shape, (400,))
self.assertEqual(my_tensor[:, 0, 0].shape, (400,))

def test_array_copying(self):
"""Check how indexing effects the returned tensor."""

size = (400, 300, 5)
my_tensor = torch.rand(size, device="cuda:0")

# obtain a slice of the tensor
my_slice = my_tensor[0, ...]
self.assertEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())

# obtain a slice over ranges
my_slice = my_tensor[0:2, ...]
self.assertEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())

# obtain a slice over list
my_slice = my_tensor[[0, 1], ...]
self.assertNotEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())

# obtain a slice over tensor
my_slice = my_tensor[torch.tensor([0, 1]), ...]
self.assertNotEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())


if __name__ == "__main__":
unittest.main()

0 comments on commit 514baa4

Please sign in to comment.