-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds observation term history support to Observation Manager #1439
Changes from 7 commits
a9b338d
f7ede99
9342d25
62ea6a6
8646503
a4d4818
da68a2b
d52992f
0f4a3a0
8b66c97
0cd10af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ | |||||||||||||||||||||||||||||||||||||||
from typing import TYPE_CHECKING | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
from omni.isaac.lab.utils import modifiers | ||||||||||||||||||||||||||||||||||||||||
from omni.isaac.lab.utils.buffers import CircularBuffer | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
from .manager_base import ManagerBase, ManagerTermBase | ||||||||||||||||||||||||||||||||||||||||
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg | ||||||||||||||||||||||||||||||||||||||||
|
@@ -174,12 +175,17 @@ def group_obs_concatenate(self) -> dict[str, bool]: | |||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]: | ||||||||||||||||||||||||||||||||||||||||
# call all terms that are classes | ||||||||||||||||||||||||||||||||||||||||
for group_cfg in self._group_obs_class_term_cfgs.values(): | ||||||||||||||||||||||||||||||||||||||||
for group_name, group_cfg in self._group_obs_class_term_cfgs.items(): | ||||||||||||||||||||||||||||||||||||||||
for term_cfg in group_cfg: | ||||||||||||||||||||||||||||||||||||||||
term_cfg.func.reset(env_ids=env_ids) | ||||||||||||||||||||||||||||||||||||||||
# reset terms with history | ||||||||||||||||||||||||||||||||||||||||
for term_name in self._group_obs_term_names[group_name]: | ||||||||||||||||||||||||||||||||||||||||
if term_name in self._group_obs_term_history_buffer[group_name]: | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids) | ||||||||||||||||||||||||||||||||||||||||
# call all modifiers that are classes | ||||||||||||||||||||||||||||||||||||||||
for mod in self._group_obs_class_modifiers: | ||||||||||||||||||||||||||||||||||||||||
mod.reset(env_ids=env_ids) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# nothing to log here | ||||||||||||||||||||||||||||||||||||||||
return {} | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
@@ -248,7 +254,7 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso | |||||||||||||||||||||||||||||||||||||||
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name]) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# evaluate terms: compute, add noise, clip, scale, custom modifiers | ||||||||||||||||||||||||||||||||||||||||
for name, term_cfg in obs_terms: | ||||||||||||||||||||||||||||||||||||||||
for term_name, term_cfg in obs_terms: | ||||||||||||||||||||||||||||||||||||||||
# compute term's value | ||||||||||||||||||||||||||||||||||||||||
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone() | ||||||||||||||||||||||||||||||||||||||||
# apply post-processing | ||||||||||||||||||||||||||||||||||||||||
|
@@ -261,8 +267,17 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso | |||||||||||||||||||||||||||||||||||||||
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) | ||||||||||||||||||||||||||||||||||||||||
if term_cfg.scale is not None: | ||||||||||||||||||||||||||||||||||||||||
obs = obs.mul_(term_cfg.scale) | ||||||||||||||||||||||||||||||||||||||||
# add value to list | ||||||||||||||||||||||||||||||||||||||||
group_obs[name] = obs | ||||||||||||||||||||||||||||||||||||||||
# Update the history buffer if observation term has history enabled | ||||||||||||||||||||||||||||||||||||||||
if term_cfg.history_length > 0: | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_history_buffer[group_name][term_name].append(obs) | ||||||||||||||||||||||||||||||||||||||||
if term_cfg.flatten_history_dim: | ||||||||||||||||||||||||||||||||||||||||
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape( | ||||||||||||||||||||||||||||||||||||||||
self._env.num_envs, -1 | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer | ||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
group_obs[term_name] = obs | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# concatenate all observations in the group together | ||||||||||||||||||||||||||||||||||||||||
if self._group_obs_concatenate[group_name]: | ||||||||||||||||||||||||||||||||||||||||
|
@@ -283,7 +298,7 @@ def _prepare_terms(self): | |||||||||||||||||||||||||||||||||||||||
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_concatenate: dict[str, bool] = dict() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_history_buffer: dict[str, dict] = dict() | ||||||||||||||||||||||||||||||||||||||||
# create a list to store modifiers that are classes | ||||||||||||||||||||||||||||||||||||||||
# we store it as a separate list to only call reset on them and prevent unnecessary calls | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list() | ||||||||||||||||||||||||||||||||||||||||
|
@@ -309,6 +324,7 @@ def _prepare_terms(self): | |||||||||||||||||||||||||||||||||||||||
self._group_obs_term_dim[group_name] = list() | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_cfgs[group_name] = list() | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_class_term_cfgs[group_name] = list() | ||||||||||||||||||||||||||||||||||||||||
group_entry_history_buffer: dict[str, CircularBuffer] = dict() | ||||||||||||||||||||||||||||||||||||||||
# read common config for the group | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms | ||||||||||||||||||||||||||||||||||||||||
# check if config is dict already | ||||||||||||||||||||||||||||||||||||||||
|
@@ -319,7 +335,7 @@ def _prepare_terms(self): | |||||||||||||||||||||||||||||||||||||||
# iterate over all the terms in each group | ||||||||||||||||||||||||||||||||||||||||
for term_name, term_cfg in group_cfg_items: | ||||||||||||||||||||||||||||||||||||||||
# skip non-obs settings | ||||||||||||||||||||||||||||||||||||||||
if term_name in ["enable_corruption", "concatenate_terms"]: | ||||||||||||||||||||||||||||||||||||||||
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]: | ||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||
# check for non config | ||||||||||||||||||||||||||||||||||||||||
if term_cfg is None: | ||||||||||||||||||||||||||||||||||||||||
|
@@ -335,10 +351,18 @@ def _prepare_terms(self): | |||||||||||||||||||||||||||||||||||||||
# check noise settings | ||||||||||||||||||||||||||||||||||||||||
if not group_cfg.enable_corruption: | ||||||||||||||||||||||||||||||||||||||||
term_cfg.noise = None | ||||||||||||||||||||||||||||||||||||||||
# check group history params and override terms | ||||||||||||||||||||||||||||||||||||||||
if group_cfg.history_length is not None: | ||||||||||||||||||||||||||||||||||||||||
term_cfg.history_length = group_cfg.history_length | ||||||||||||||||||||||||||||||||||||||||
term_cfg.flatten_history_dim = group_cfg.flatten_history_dim | ||||||||||||||||||||||||||||||||||||||||
# add term config to list to list | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_names[group_name].append(term_name) | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_cfgs[group_name].append(term_cfg) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# create history buffers | ||||||||||||||||||||||||||||||||||||||||
if term_cfg.history_length > 0: | ||||||||||||||||||||||||||||||||||||||||
group_entry_history_buffer[term_name] = CircularBuffer( | ||||||||||||||||||||||||||||||||||||||||
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
# call function the first time to fill up dimensions | ||||||||||||||||||||||||||||||||||||||||
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape) | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This populates the correct |
||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_dim[group_name].append(obs_dims[1:]) | ||||||||||||||||||||||||||||||||||||||||
|
@@ -411,3 +435,5 @@ def _prepare_terms(self): | |||||||||||||||||||||||||||||||||||||||
self._group_obs_class_term_cfgs[group_name].append(term_cfg) | ||||||||||||||||||||||||||||||||||||||||
# call reset (in-case above call to get obs dims changed the state) | ||||||||||||||||||||||||||||||||||||||||
term_cfg.func.reset() | ||||||||||||||||||||||||||||||||||||||||
# add history buffers for each group | ||||||||||||||||||||||||||||||||||||||||
self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,6 +75,16 @@ def current_length(self) -> torch.Tensor: | |
""" | ||
return torch.minimum(self._num_pushes, self._max_len) | ||
|
||
@property | ||
def buffer(self) -> torch.Tensor: | ||
"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning. | ||
Returns: | ||
Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]]. | ||
""" | ||
buf = self._buffer.clone() | ||
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0) | ||
return torch.transpose(buf, dim0=0, dim1=1) | ||
|
||
""" | ||
Operations. | ||
""" | ||
|
@@ -89,8 +99,10 @@ def reset(self, batch_ids: Sequence[int] | None = None): | |
if batch_ids is None: | ||
batch_ids = slice(None) | ||
# reset the number of pushes for the specified batch indices | ||
# note: we don't need to reset the buffer since it will be overwritten. The pointer handles this. | ||
self._num_pushes[batch_ids] = 0 | ||
if self._buffer is not None: | ||
# set buffer at batch_id reset indices to 0.0 so that the buffer() getter returns the cleared circular buffer after reset. | ||
self._buffer[:, batch_ids, :] = 0.0 | ||
|
||
def append(self, data: torch.Tensor): | ||
"""Append the data to the circular buffer. | ||
|
@@ -109,7 +121,7 @@ def append(self, data: torch.Tensor): | |
# at the fist call, initialize the buffer | ||
if self._buffer is None: | ||
self._pointer = -1 | ||
self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device) | ||
self._buffer = torch.zeros((self.max_length, *data.shape), dtype=data.dtype, device=self._device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jtigue-bdai Isn't it better to initialize the buffer to the latest data on the first append or right after a reset? Otherwise, we assume that zero is a valid value for the observation, which may not always be true. I would suggest that, during reset or at first init, all the indices in the history is initialized to the most recent data, possibly the one being passed into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thats a good point. I think you are right about filling on the first append. |
||
# move the head to the next slot | ||
self._pointer = (self._pointer + 1) % self.max_length | ||
# add the new data to the last layer | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot find how to provide a suggestion to lines that are not part of the changelist, so here we go.
In the method
def __str__(self) -> str:
,You may need to re-compute the observation dimensions depending on whether flattening is enabled or not.
I see two deficiencies which can cause confusions:
self._group_obs_dim
is wrong and does not correspond to the actual observation dimension if history is used. This is quite critical to be fixed IMO.Hope this helps!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the catch I will address this.