Adds observation term history support to Observation Manager #1439

Merged
merged 11 commits into from
Dec 16, 2024
2 changes: 1 addition & 1 deletion source/extensions/omni.isaac.lab/config/extension.toml
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.29"
version = "0.28.0"

# Description
title = "Isaac Lab framework for Robot Learning"
12 changes: 12 additions & 0 deletions source/extensions/omni.isaac.lab/docs/CHANGELOG.rst
@@ -1,6 +1,18 @@
Changelog
---------

0.28.0 (2024-12-15)
~~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added observation history computation to :class:`omni.isaac.lab.managers.observation_manager.ObservationManager`.
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.managers.manager_term_cfg.ObservationTermCfg`.
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.managers.manager_term_cfg.ObservationGroupCfg`.
* Added a ``buffer`` property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`.


0.27.29 (2024-12-15)
~~~~~~~~~~~~~~~~~~~~

@@ -180,6 +180,19 @@ class ObservationTermCfg(ManagerTermBaseCfg):
please make sure the length of the tuple matches the dimensions of the tensor output by the term.
"""

history_length: int = 0
"""Number of past observations to store in the observation buffers. Defaults to 0, meaning no history.

The history buffer starts empty and is filled entirely by the first append after initialization or reset;
each subsequent append adds a single entry to the buffer. If :attr:`flatten_history_dim` is set to True, the
source data of shape (N, H, D, ...), where N is the batch dimension and H is the history length, is reshaped
to a 2D tensor of shape (N, H*D*...). Otherwise, the data is returned as is.
"""

flatten_history_dim: bool = True
"""Whether the observation manager should flatten history-based observation terms to a 2D (N, D) tensor.
Defaults to True."""


@configclass
class ObservationGroupCfg:
@@ -201,6 +214,22 @@ class ObservationGroupCfg:
Otherwise, no corruption is applied.
"""

history_length: int | None = None
"""Number of past observations to store in the observation buffers for all observation terms in the group.

This parameter overrides :attr:`ObservationTermCfg.history_length` if set. Defaults to None, in which case
each term's history is controlled on a per-term basis. See :class:`ObservationTermCfg` for details on the
:attr:`~ObservationTermCfg.history_length` implementation.
"""

flatten_history_dim: bool = True
"""Flag to flatten history-based observation terms to a 2D (num_env, D) tensor for all observation terms
in the group. Defaults to True.

This parameter overrides :attr:`ObservationTermCfg.flatten_history_dim` for every term in the group if
:attr:`history_length` is set.
"""


##
# Event manager
@@ -8,12 +8,14 @@
from __future__ import annotations

import inspect
import numpy as np
import torch
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING

from omni.isaac.lab.utils import modifiers
from omni.isaac.lab.utils.buffers import CircularBuffer

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg
Contributor
I cannot find how to provide a suggestion to lines that are not part of the changelist, so here we go.

In the method `def __str__(self) -> str:`, you may need to re-compute the observation dimensions depending on whether flattening is enabled or not.
I see three deficiencies which can cause confusion:

  1. Currently, the printed summary does not handle the history length while computing the observation dimension.
  2. The `self._group_obs_dim` is wrong and does not correspond to the actual observation dimension if history is used. This is quite critical to fix IMO.
  3. The shape of the observation group printed in the summary is wrong.

Hope this helps!
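The dimension bookkeeping the reviewer asks for can be sketched with a hypothetical helper (mirroring the logic this PR adds to `_prepare_terms`; the function name is illustrative):

```python
import numpy as np

def history_obs_dims(obs_dims, history_length, flatten):
    """Expand a per-term shape (N, D, ...) to (N, H, D, ...) when history is
    enabled, or to (N, H*D*...) when the history dimension is flattened."""
    if history_length > 0:
        obs_dims = (obs_dims[0], history_length, *obs_dims[1:])
        if flatten:
            obs_dims = (obs_dims[0], int(np.prod(obs_dims[1:])))
    return obs_dims

print(history_obs_dims((16, 4), 5, True))        # (16, 20)
print(history_obs_dims((16, 128, 256), 5, False))  # (16, 5, 128, 256)
```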

Collaborator Author
Thanks for the catch, I will address this.

@@ -45,6 +47,11 @@ class ObservationManager(ManagerBase):
concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the
group configuration to False.

Observations can also maintain a history. In this case, a running history buffer is updated at every
simulation step. History can be controlled per term via :class:`ObservationTermCfg` (see
:attr:`ObservationTermCfg.history_length` and :attr:`ObservationTermCfg.flatten_history_dim`), or via
:class:`ObservationGroupCfg`, in which case the group configuration overrides the per-term configuration
if set. History is ordered from oldest to newest entry.

The observation manager can be used to compute observations for all the groups or for a specific group. The
observations are computed by calling the registered functions for each term in the group. The functions are
called in the order of the terms in the group. The functions are expected to return a tensor with shape
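The oldest-to-newest ordering can be sketched with a tiny ring buffer (assumed behavior, mirroring the roll-and-transpose in the `CircularBuffer.buffer` property added in this PR; sizes are illustrative):

```python
import torch

max_len, batch = 3, 2
storage = torch.zeros(max_len, batch)  # internal layout: [max_len, batch]
pointer = -1
for step in range(5):  # 5 appends overflow a length-3 buffer
    pointer = (pointer + 1) % max_len
    storage[pointer] = float(step)

# Roll so the oldest surviving entry comes first and the newest last,
# then move batch to the front: shape becomes [batch, max_len].
ordered = torch.roll(storage, shifts=max_len - pointer - 1, dims=0)
ordered = ordered.transpose(0, 1)
print(ordered[0].tolist())  # [2.0, 3.0, 4.0] -- oldest to newest
```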
@@ -174,12 +181,17 @@ def group_obs_concatenate(self) -> dict[str, bool]:

def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# call all terms that are classes
for group_cfg in self._group_obs_class_term_cfgs.values():
for group_name, group_cfg in self._group_obs_class_term_cfgs.items():
for term_cfg in group_cfg:
term_cfg.func.reset(env_ids=env_ids)
# reset terms with history
for term_name in self._group_obs_term_names[group_name]:
if term_name in self._group_obs_term_history_buffer[group_name]:
self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids)
# call all modifiers that are classes
for mod in self._group_obs_class_modifiers:
mod.reset(env_ids=env_ids)

# nothing to log here
return {}

@@ -248,7 +260,7 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])

# evaluate terms: compute, add noise, clip, scale, custom modifiers
for name, term_cfg in obs_terms:
for term_name, term_cfg in obs_terms:
# compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
# apply post-processing
@@ -261,8 +273,17 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale is not None:
obs = obs.mul_(term_cfg.scale)
# add value to list
group_obs[name] = obs
# Update the history buffer if observation term has history enabled
if term_cfg.history_length > 0:
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
self._env.num_envs, -1
)
else:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
else:
group_obs[term_name] = obs

# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
@@ -283,7 +304,7 @@ def _prepare_terms(self):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()

self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
@@ -309,6 +330,7 @@ def _prepare_terms(self):
self._group_obs_term_dim[group_name] = list()
self._group_obs_term_cfgs[group_name] = list()
self._group_obs_class_term_cfgs[group_name] = list()
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
# check if config is dict already
@@ -319,7 +341,7 @@ def _prepare_terms(self):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms"]:
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
continue
# check for non config
if term_cfg is None:
@@ -335,12 +357,26 @@ def _prepare_terms(self):
# check noise settings
if not group_cfg.enable_corruption:
term_cfg.noise = None
# check group history params and override terms
if group_cfg.history_length is not None:
term_cfg.history_length = group_cfg.history_length
term_cfg.flatten_history_dim = group_cfg.flatten_history_dim
# add term config to list
self._group_obs_term_names[group_name].append(term_name)
self._group_obs_term_cfgs[group_name].append(term_cfg)

# call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
# create history buffers and calculate history term dimensions
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
)
old_dims = list(obs_dims)
old_dims.insert(1, term_cfg.history_length)
obs_dims = tuple(old_dims)
if term_cfg.flatten_history_dim:
obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))

self._group_obs_term_dim[group_name].append(obs_dims[1:])

# if scale is set, check if single float or tuple
@@ -411,3 +447,5 @@ def _prepare_terms(self):
self._group_obs_class_term_cfgs[group_name].append(term_cfg)
# call reset (in-case above call to get obs dims changed the state)
term_cfg.func.reset()
# add history buffers for each group
self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer
@@ -75,6 +75,16 @@ def current_length(self) -> torch.Tensor:
"""
return torch.minimum(self._num_pushes, self._max_len)

@property
def buffer(self) -> torch.Tensor:
"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning.
Returns:
Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].
"""
buf = self._buffer.clone()
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
return torch.transpose(buf, dim0=0, dim1=1)

"""
Operations.
"""
@@ -89,8 +99,10 @@ def reset(self, batch_ids: Sequence[int] | None = None):
if batch_ids is None:
batch_ids = slice(None)
# reset the number of pushes for the specified batch indices
# note: we don't need to reset the buffer since it will be overwritten. The pointer handles this.
self._num_pushes[batch_ids] = 0
if self._buffer is not None:
# zero out the buffer at the reset batch indices so that the buffer property returns the cleared history after a reset
self._buffer[:, batch_ids, :] = 0.0

def append(self, data: torch.Tensor):
"""Append the data to the circular buffer.
@@ -106,15 +118,20 @@ def append(self, data: torch.Tensor):
if data.shape[0] != self.batch_size:
raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")

# at the fist call, initialize the buffer
# at the first call, allocate the buffer storage
if self._buffer is None:
self._pointer = -1
self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
# move the head to the next slot
self._pointer = (self._pointer + 1) % self.max_length
# add the new data to the last layer
self._buffer[self._pointer] = data.to(self._device)
# increment number of number of pushes
# check for batches with zero pushes and fill their entire history with the first appended value
if 0 in self._num_pushes.tolist():
fill_ids = [i for i, x in enumerate(self._num_pushes.tolist()) if x == 0]
self._buffer[:, fill_ids, :] = data.to(self._device)[fill_ids]
# increment the number of pushes for all batches
self._num_pushes += 1
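The zero-push fill above can be isolated in a standalone sketch (shapes are illustrative; the internal layout `[max_len, batch, dim]` mirrors this class):

```python
import torch

max_len, batch, dim = 3, 4, 2
buffer = torch.zeros(max_len, batch, dim)
num_pushes = torch.zeros(batch, dtype=torch.long)

# First append: batches with zero pushes get their whole history filled with
# the first sample, so reads never return stale zeros.
data = torch.ones(batch, dim)
fill_ids = [i for i, n in enumerate(num_pushes.tolist()) if n == 0]
buffer[:, fill_ids, :] = data[fill_ids]  # broadcast sample over the history dim
num_pushes += 1
```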

def __getitem__(self, key: torch.Tensor) -> torch.Tensor:
@@ -131,8 +131,51 @@ class SampleGroupCfg(ObservationGroupCfg):
self.obs_man = ObservationManager(cfg, self.env)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
# print the expected string
obs_man_str = str(self.obs_man)
print()
print(self.obs_man)
print(obs_man_str)
obs_man_str_split = obs_man_str.split("|")
term_1_str_index = obs_man_str_split.index(" term_1 ")
term_1_str_shape = obs_man_str_split[term_1_str_index + 1].strip()
self.assertEqual(term_1_str_shape, "(4,)")

def test_str_with_history(self):
"""Test the string representation of the observation manager with history terms."""

TERM_1_HISTORY = 5

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class SampleGroupCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(func="__main__:grilled_chicken", scale=10, history_length=TERM_1_HISTORY)
term_2 = ObservationTermCfg(func=grilled_chicken, scale=2)
term_3 = ObservationTermCfg(func=grilled_chicken_with_bbq, scale=5, params={"bbq": True})
term_4 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt, scale=1.0, params={"hot": False, "bland": 2.0}
)
term_5 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt_and_bbq, scale=1.0, params={"hot": False, "bland": 2.0}
)

policy: ObservationGroupCfg = SampleGroupCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
# print the expected string
obs_man_str = str(self.obs_man)
print()
print(obs_man_str)
obs_man_str_split = obs_man_str.split("|")
term_1_str_index = obs_man_str_split.index(" term_1 ")
term_1_str_shape = obs_man_str_split[term_1_str_index + 1].strip()
self.assertEqual(term_1_str_shape, "(20,)")

def test_config_equivalence(self):
"""Test the equivalence of observation manager created from different config types."""
@@ -304,6 +347,157 @@ class ImageCfg(ObservationGroupCfg):
torch.testing.assert_close(obs_policy[:, 5:8], obs_critic[:, 0:3])
torch.testing.assert_close(obs_policy[:, 8:11], obs_critic[:, 3:6])

def test_compute_with_history(self):
"""Test the observation computation with history buffers."""
HISTORY_LENGTH = 5

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(func=grilled_chicken, history_length=HISTORY_LENGTH)
# total observation size: term_dim (4) * history_len (5) = 20
term_2 = ObservationTermCfg(func=lin_vel_w_data)
# total observation size: term_dim (3) = 3

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
# check the observation shape
self.assertEqual((self.env.num_envs, 23), obs_policy.shape)
# check the observation data
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
print(expected_obs_data_t0, obs_policy)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data
for _ in range(HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t5, obs_policy))
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))

def test_compute_with_2d_history(self):
"""Test the observation computation with history buffers for 2D observations."""
HISTORY_LENGTH = 5

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class FlattenedPolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(
func=grilled_chicken_image, params={"bland": 1.0, "channel": 1}, history_length=HISTORY_LENGTH
)
# total observation size: term_dim (128, 256) * history_len (5) = 163840

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

term_1 = ObservationTermCfg(
func=grilled_chicken_image,
params={"bland": 1.0, "channel": 1},
history_length=HISTORY_LENGTH,
flatten_history_dim=False,
)
# total observation size: (5, 128, 256, 1)

flat_obs_policy: ObservationGroupCfg = FlattenedPolicyCfg()
policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy_flat: torch.Tensor = observations["flat_obs_policy"]
obs_policy: torch.Tensor = observations["policy"]
# check the observation shapes
self.assertEqual((self.env.num_envs, 163840), obs_policy_flat.shape)
self.assertEqual((self.env.num_envs, HISTORY_LENGTH, 128, 256, 1), obs_policy.shape)

def test_compute_with_group_history(self):
"""Test the observation computation with group level history buffer configuration."""
TERM_HISTORY_LENGTH = 5
GROUP_HISTORY_LENGTH = 10

@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""

@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""

history_length = GROUP_HISTORY_LENGTH
# group level history length will override all terms
term_1 = ObservationTermCfg(func=grilled_chicken, history_length=TERM_HISTORY_LENGTH)
# total observation size: term_dim (4) * history_len (5) = 20
# with override total obs size: term_dim (4) * history_len (10) = 40
term_2 = ObservationTermCfg(func=lin_vel_w_data)
# total observation size: term_dim (3) = 3
# with override total obs size: term_dim (3) * history_len (10) = 30

policy: ObservationGroupCfg = PolicyCfg()

# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
# check the total observation shape
self.assertEqual((self.env.num_envs, 70), obs_policy.shape)
# check the observation data is initialized properly
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data
for _ in range(GROUP_HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t10 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t10, obs_policy))
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))

def test_invalid_observation_config(self):
"""Test the invalid observation config."""

@@ -46,9 +46,31 @@ def test_reset(self):
# reset the buffer
self.buffer.reset()

# check if the buffer is empty
# check that the buffer has zero entries
self.assertEqual(self.buffer.current_length.tolist(), [0, 0, 0])

def test_reset_subset(self):
"""Test resetting a subset of batches in the circular buffer."""
data1 = torch.ones((self.batch_size, 2), device=self.device)
data2 = 2.0 * data1.clone()
data3 = 3.0 * data1.clone()
self.buffer.append(data1)
self.buffer.append(data2)
# reset the buffer
reset_batch_id = 1
self.buffer.reset(batch_ids=[reset_batch_id])
# check that correct batch is reset
self.assertEqual(self.buffer.current_length.tolist()[reset_batch_id], 0)
# Append new set of data
self.buffer.append(data3)
# check if the correct number of entries are in each batch
expected_length = [3, 3, 3]
expected_length[reset_batch_id] = 1
self.assertEqual(self.buffer.current_length.tolist(), expected_length)
# check that all entries of the recently reset and appended batch are equal
for i in range(self.max_len):
torch.testing.assert_close(self.buffer.buffer[reset_batch_id, 0], self.buffer.buffer[reset_batch_id, i])

def test_append_and_retrieve(self):
"""Test appending and retrieving data from the circular buffer."""
# append some data
@@ -121,6 +143,33 @@ def test_key_greater_than_pushes(self):
retrieved_data = self.buffer[torch.tensor([5, 5, 5], device=self.device)]
self.assertTrue(torch.equal(retrieved_data, data1))

def test_return_buffer_prop(self):
"""Test retrieving the whole buffer for correct size and contents.
Returning the whole buffer should have the shape [batch_size,max_len,data.shape[1:]]
"""
num_overflow = 2
for i in range(self.buffer.max_length + num_overflow):
data = torch.tensor([[i]], device=self.device).repeat(3, 2)
self.buffer.append(data)

retrieved_buffer = self.buffer.buffer
# check shape
self.assertTrue(retrieved_buffer.shape == torch.Size([self.buffer.batch_size, self.buffer.max_length, 2]))
# check that batch is first dimension
torch.testing.assert_close(retrieved_buffer[0], retrieved_buffer[1])
# check oldest
torch.testing.assert_close(
retrieved_buffer[:, 0], torch.tensor([[num_overflow]], device=self.device).repeat(3, 2)
)
# check most recent
torch.testing.assert_close(
retrieved_buffer[:, -1],
torch.tensor([[self.buffer.max_length + num_overflow - 1]], device=self.device).repeat(3, 2),
)
# check that it is returned oldest first
for idx in range(self.buffer.max_length - 1):
self.assertTrue(torch.all(torch.le(retrieved_buffer[:, idx], retrieved_buffer[:, idx + 1])))


if __name__ == "__main__":
run_tests()