Adds observation term history support to Observation Manager #1439

Merged
merged 11 commits on Dec 16, 2024
2 changes: 1 addition & 1 deletion source/extensions/omni.isaac.lab/config/extension.toml
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.15"
version = "0.28.0"

# Description
title = "Isaac Lab framework for Robot Learning"
12 changes: 12 additions & 0 deletions source/extensions/omni.isaac.lab/docs/CHANGELOG.rst
@@ -1,6 +1,18 @@
Changelog
---------

0.28.0 (2024-11-22)
~~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added observation history computation to :class:`omni.isaac.lab.managers.observation_manager.ObservationManager`.
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.managers.manager_term_cfg.ObservationTermCfg`.
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.managers.manager_term_cfg.ObservationGroupCfg`.
* Added the ``buffer`` property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`.


0.27.15 (2024-11-09)
~~~~~~~~~~~~~~~~~~~~

source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py
@@ -160,6 +160,16 @@ class ObservationTermCfg(ManagerTermBaseCfg):
please make sure the length of the tuple matches the dimensions of the tensor output by the term.
"""

history_length: int = 0
"""Number of past observations to store in the zero-initialized observation buffers.

Defaults to 0, which means that only the current data is stored (no history). If ``flatten_history_dim``
is set to True, the source data of shape (N, H, D, ...), where N is the batch dimension and H is the
history length, will be reshaped to a 2D tensor of shape (N, H*D*...). Otherwise, the data is returned as is.
"""

flatten_history_dim: bool = True
"""Whether or not the observation manager should flatten history-based observation terms to a 2D (N, D) tensor. Defaults to True."""


@configclass
class ObservationGroupCfg:
@@ -181,6 +191,21 @@ class ObservationGroupCfg:
Otherwise, no corruption is applied.
"""

history_length: int | None = None
"""Number of past observations to store in the zero-initialized observation buffers for all observation terms in the group.

This parameter overrides ObservationTermCfg.history_length if set. Defaults to None. If None, the
history of each term is controlled on a per-term basis via ObservationTermCfg.history_length. See
ObservationTermCfg for details on the history_length implementation.
"""

flatten_history_dim: bool = True
"""Flag to flatten history-based observation terms to a 2D (num_envs, D) tensor for all observation terms in the group.

This parameter overrides ObservationTermCfg.flatten_history_dim for every term in the group if
ObservationGroupCfg.history_length is set. Defaults to True.
"""


##
# Event manager
source/extensions/omni.isaac.lab/omni/isaac/lab/managers/observation_manager.py
@@ -14,6 +14,7 @@
from typing import TYPE_CHECKING

from omni.isaac.lab.utils import modifiers
from omni.isaac.lab.utils.buffers import CircularBuffer

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg
Contributor:

I cannot find how to provide a suggestion to lines that are not part of the changelist, so here we go.

In the method def __str__(self) -> str, you may need to re-compute the observation dimensions depending on whether flattening is enabled or not.

I see three deficiencies which can cause confusion:

  1. Currently, the printed summary does not account for the history length when computing the observation dimension.
  2. self._group_obs_dim does not correspond to the actual observation dimension if history is used. This is quite critical to fix, IMO.
  3. The shape of the observation group printed in the summary is wrong.

Hope this helps!
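
A minimal sketch of the dimension fix being discussed (a hypothetical helper, not code from this PR), assuming the raw shape returned by the first call to term_cfg.func and the two new config flags:

    import numpy as np

    def effective_term_dims(obs_dims: tuple, history_length: int, flatten_history_dim: bool) -> tuple:
        """Per-env dims of a term after history stacking and optional flattening."""
        dims = obs_dims[1:]  # drop the batch (num_envs) dimension
        if history_length > 0:
            # history is stacked along a new leading dimension
            dims = (history_length, *dims)
            if flatten_history_dim:
                # e.g. (5, 4) -> (20,)
                dims = (int(np.prod(dims)),)
        return dims

The printed summary and self._group_obs_dim could then be derived from these effective dims instead of the raw ones.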

Collaborator (author):

Thanks for the catch, I will address this.

@@ -174,12 +175,17 @@ def group_obs_concatenate(self) -> dict[str, bool]:

def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# call all terms that are classes
for group_cfg in self._group_obs_class_term_cfgs.values():
for group_name, group_cfg in self._group_obs_class_term_cfgs.items():
for term_cfg in group_cfg:
term_cfg.func.reset(env_ids=env_ids)
# reset terms with history
for term_name in self._group_obs_term_names[group_name]:
if term_name in self._group_obs_term_history_buffer[group_name]:
self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids)
# call all modifiers that are classes
for mod in self._group_obs_class_modifiers:
mod.reset(env_ids=env_ids)

# nothing to log here
return {}

Expand Down Expand Up @@ -248,7 +254,7 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])

# evaluate terms: compute, add noise, clip, scale, custom modifiers
for name, term_cfg in obs_terms:
for term_name, term_cfg in obs_terms:
# compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
# apply post-processing
@@ -261,8 +267,17 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]:
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale is not None:
obs = obs.mul_(term_cfg.scale)
# add value to list
group_obs[name] = obs
# update the history buffer if the observation term has history enabled
if term_cfg.history_length > 0:
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
self._env.num_envs, -1
)
else:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
else:
group_obs[term_name] = obs

# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
@@ -283,7 +298,7 @@ def _prepare_terms(self):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()

self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
@@ -309,6 +324,7 @@ def _prepare_terms(self):
self._group_obs_term_dim[group_name] = list()
self._group_obs_term_cfgs[group_name] = list()
self._group_obs_class_term_cfgs[group_name] = list()
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
# check if config is dict already
@@ -319,7 +335,7 @@ def _prepare_terms(self):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms"]:
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
continue
# check for non config
if term_cfg is None:
@@ -335,10 +351,18 @@ def _prepare_terms(self):
# check noise settings
if not group_cfg.enable_corruption:
term_cfg.noise = None
# check group history params and override terms
if group_cfg.history_length is not None:
term_cfg.history_length = group_cfg.history_length
term_cfg.flatten_history_dim = group_cfg.flatten_history_dim
# add term config to list
self._group_obs_term_names[group_name].append(term_name)
self._group_obs_term_cfgs[group_name].append(term_cfg)

# create history buffers
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
)
# call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
Contributor (@aravindev, Nov 28, 2024):

Suggested change

Before:

    # create history buffers
    if term_cfg.history_length > 0:
        group_entry_history_buffer[term_name] = CircularBuffer(
            max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
        )
    # call function the first time to fill up dimensions
    obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)

After:

    # call function the first time to fill up dimensions
    obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
    # create history buffers
    if term_cfg.history_length > 0:
        group_entry_history_buffer[term_name] = CircularBuffer(
            max_len=term_cfg.history_length,
            batch_size=self._env.num_envs,
            device=self._env.device,
        )
        obs_dims = (obs_dims[0], term_cfg.history_length * obs_dims[1], *obs_dims[2:])
        if term_cfg.flatten_history_dim:
            obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))

This populates the correct _group_obs_term_dim into the dictionary. Later, this is used to compute _group_obs_dim; however, I assume that its computation does not need any changes.
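
As a hypothetical sanity check of the suggested computation against the shapes used in the tests below (num_envs = 16 is an assumed value):

    import numpy as np

    # vector term: per-env dim 4, history_length = 5, flattened
    obs_dims = (16, 4)
    obs_dims = (obs_dims[0], 5 * obs_dims[1], *obs_dims[2:])  # -> (16, 20)
    obs_dims = (obs_dims[0], int(np.prod(obs_dims[1:])))      # -> (16, 20)

    # image term: per-env shape (128, 256, 1), history_length = 5, flattened
    img_dims = (16, 128, 256, 1)
    img_dims = (img_dims[0], 5 * img_dims[1], *img_dims[2:])  # -> (16, 640, 256, 1)
    img_dims = (img_dims[0], int(np.prod(img_dims[1:])))      # -> (16, 163840)

These match the 20- and 163840-dimensional observations asserted in the new tests.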

self._group_obs_term_dim[group_name].append(obs_dims[1:])
@@ -411,3 +435,5 @@ def _prepare_terms(self):
self._group_obs_class_term_cfgs[group_name].append(term_cfg)
# call reset (in-case above call to get obs dims changed the state)
term_cfg.func.reset()
# add history buffers for each group
self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer
source/extensions/omni.isaac.lab/omni/isaac/lab/utils/buffers/circular_buffer.py
@@ -75,6 +75,16 @@ def current_length(self) -> torch.Tensor:
"""
return torch.minimum(self._num_pushes, self._max_len)

@property
def buffer(self) -> torch.Tensor:
"""Complete circular buffer with the oldest entry at the beginning and the most recent entry at the end.

Returns:
Complete circular buffer, ordered from oldest to most recent entry along dimension 1. The shape is [batch_size, max_length, *data.shape[1:]].
"""
# clone so that rolling does not modify the internal storage
buf = self._buffer.clone()
# roll along the time dimension so the oldest entry (the slot just after the write pointer) moves to index 0
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
# transpose to batch-first layout: [batch_size, max_length, ...]
return torch.transpose(buf, dim0=0, dim1=1)

"""
Operations.
"""
@@ -89,8 +99,10 @@ def reset(self, batch_ids: Sequence[int] | None = None):
if batch_ids is None:
batch_ids = slice(None)
# reset the number of pushes for the specified batch indices
self._num_pushes[batch_ids] = 0
if self._buffer is not None:
# zero the buffer at the reset batch indices so that the buffer property returns cleared history after a reset
self._buffer[:, batch_ids, :] = 0.0

def append(self, data: torch.Tensor):
"""Append the data to the circular buffer.
@@ -109,7 +121,7 @@ def append(self, data: torch.Tensor):
# at the first call, initialize the buffer
if self._buffer is None:
self._pointer = -1
self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
self._buffer = torch.zeros((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
Contributor:

@jtigue-bdai Isn't it better to initialize the buffer to the latest data on the first append, or right after a reset?

Otherwise, we assume that zero is a valid value for the observation, which may not always be true. For example, if we are gathering a history of, let's say, the gravitational force, which may be defined as a strictly negative value, a buffer filled with zeros as previous observations may not be within the expected distribution.

I would suggest that, during reset or at first init, all the indices in the history are initialized to the most recent data, possibly the one being passed into append().

Collaborator (author):

That's a good point. I think you are right about filling on the first append.
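
A minimal sketch of that fill-on-first-append idea (an assumed variant, not the code merged in this PR); it reuses CircularBuffer's internal names (_buffer, _pointer, _num_pushes) and back-fills every history slot of a batch entry on its first push after init or reset:

    def append(self, data: torch.Tensor):
        # lazily allocate storage on the first call
        if self._buffer is None:
            self._pointer = -1
            self._buffer = torch.zeros((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
        # batch entries that have never been pushed (or were just reset) get the
        # incoming value in every history slot, so the history starts in-distribution
        fill_ids = torch.nonzero(self._num_pushes == 0).flatten()
        if len(fill_ids) > 0:
            self._buffer[:, fill_ids] = data[fill_ids].unsqueeze(0)
        # advance the write pointer and store the newest frame
        self._pointer = (self._pointer + 1) % self.max_length
        self._buffer[self._pointer] = data
        self._num_pushes += 1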

# move the head to the next slot
self._pointer = (self._pointer + 1) % self.max_length
# add the new data to the last layer
source/extensions/omni.isaac.lab/test/managers/test_observation_manager.py
@@ -304,6 +304,159 @@ class ImageCfg(ObservationGroupCfg):
torch.testing.assert_close(obs_policy[:, 5:8], obs_critic[:, 0:3])
torch.testing.assert_close(obs_policy[:, 8:11], obs_critic[:, 3:6])

def test_compute_with_history(self):
"""Test the observation computation with history buffers."""
HISTORY_LENGTH = 5

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(func=grilled_chicken, history_length=HISTORY_LENGTH)
# total observation size: term_dim (4) * history_len (5) = 20
term_2 = ObservationTermCfg(func=lin_vel_w_data)
# total observation size: term_dim (3) = 3

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
# check the observation shape
self.assertEqual((self.env.num_envs, 23), obs_policy.shape)
# check the observation data
expected_obs_term_1_data = torch.zeros(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_term_1_data[:, -4:] = torch.ones(4, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data
for _ in range(HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t5, obs_policy))
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))

def test_compute_with_2d_history(self):
"""Test the observation computation with history buffers for 2D observations."""
HISTORY_LENGTH = 5

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class FlattenedPolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(
func=grilled_chicken_image, params={"bland": 1.0, "channel": 1}, history_length=HISTORY_LENGTH
)
# total observation size: term_dim (128, 256) * history_len (5) = 163840

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(
func=grilled_chicken_image,
params={"bland": 1.0, "channel": 1},
history_length=HISTORY_LENGTH,
flatten_history_dim=False,
)
# total observation size: (5, 128, 256, 1)

flat_obs_policy: ObservationGroupCfg = FlattenedPolicyCfg()
policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy_flat: torch.Tensor = observations["flat_obs_policy"]
obs_policy: torch.Tensor = observations["policy"]
# check the observation shapes
self.assertEqual((self.env.num_envs, 163840), obs_policy_flat.shape)
self.assertEqual((self.env.num_envs, HISTORY_LENGTH, 128, 256, 1), obs_policy.shape)

def test_compute_with_group_history(self):
"""Test the observation computation with group level history buffer configuration."""
TERM_HISTORY_LENGTH = 5
GROUP_HISTORY_LENGTH = 10

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

history_length = GROUP_HISTORY_LENGTH
# group level history length will override all terms
term_1 = ObservationTermCfg(func=grilled_chicken, history_length=TERM_HISTORY_LENGTH)
# total observation size: term_dim (4) * history_len (5) = 20
# with override total obs size: term_dim (4) * history_len (10) = 40
term_2 = ObservationTermCfg(func=lin_vel_w_data)
# total observation size: term_dim (3) = 3
# with override total obs size: term_dim (3) * history_len (10) = 30

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
# check the total observation shape
self.assertEqual((self.env.num_envs, 70), obs_policy.shape)
# check the observation data
expected_obs_term_1_data = torch.zeros(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_1_data[:, -4:] = torch.ones(4, device=self.env.device)
expected_obs_term_2_data = torch.zeros(self.env.num_envs, 3 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data[:, -3:] = lin_vel_w_data(self.env)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data
for _ in range(GROUP_HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t10 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t10, obs_policy))
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))

def test_invalid_observation_config(self):
"""Test the invalid observation config."""

Expand Down