Replies: 1 comment
-
# The source code is here.
class _CompositeSpecItemsView:
    """Wrapper class that enables richer behaviour of `items` for CompositeSpec.

    Iterating over an instance yields ``(key, spec)`` pairs taken from
    ``composite``:

    * Top-level entries come from ``composite._specs.items()``, in that
      mapping's iteration order (presumably a dict, i.e. insertion order —
      confirm). For a ``LazyStackedCompositeSpec`` they are instead either the
      keys shared by every stacked spec, sorted by ``str``, or — when
      ``is_leaf`` is ``_NESTED_TENSORS_AS_LISTS`` — the entries of each
      stacked spec under an ``(str(index), key)`` key.
    * With ``include_nested=True``, a nested ``CompositeSpec`` entry is
      expanded first: its own entries (from
      ``item.items(include_nested=True, ...)``) are yielded with the parent
      key prepended, and only afterwards is the ``(key, nested_spec)`` pair
      itself yielded (when ``leaves_only`` is false). Assuming
      ``CompositeSpec.items`` returns this same view type, that holds at every
      depth, i.e. a node is yielded after all of its descendants (post-order).
      NOTE(review): this ordering is an implementation detail of this class,
      not a documented guarantee — confirm before relying on it.
    * With ``leaves_only=True``, entries at this level are yielded only if the
      leaf predicate holds for their type; nested entries are filtered by the
      nested ``items`` call, which receives the same ``leaves_only``.

    Args:
        composite: the spec whose entries are iterated.
        include_nested: whether to descend into nested ``CompositeSpec``
            entries, yielding tuple keys for their contents.
        leaves_only: whether to skip entries that are not leaves.
        is_leaf: callable ``cls -> bool`` deciding what counts as a leaf.
            ``None`` or ``_NESTED_TENSORS_AS_LISTS`` selects the default
            predicate "not a ``CompositeSpec`` subclass".
    """

    def __init__(
        self,
        composite: CompositeSpec,
        include_nested,
        leaves_only,
        *,
        is_leaf,
    ):
        self.composite = composite
        self.leaves_only = leaves_only
        self.include_nested = include_nested
        self.is_leaf = is_leaf

    def __iter__(self):
        from tensordict.base import _NESTED_TENSORS_AS_LISTS

        is_leaf = self.is_leaf
        if is_leaf in (None, _NESTED_TENSORS_AS_LISTS):
            # Default leaf predicate: anything that is not a composite spec.
            def _is_leaf(cls):
                return not issubclass(cls, CompositeSpec)

        else:
            _is_leaf = is_leaf

        def _iter_from_item(key, item):
            # NOTE(review): when include_nested is true, a CompositeSpec entry
            # is always expanded, even if a custom is_leaf calls it a leaf; in
            # that case the entry itself is never yielded — confirm intended.
            if self.include_nested and isinstance(item, CompositeSpec):
                # Children first, with the parent key prepended ...
                for subkey, subitem in item.items(
                    include_nested=True,
                    leaves_only=self.leaves_only,
                    # The caller's original is_leaf (possibly None or the
                    # sentinel) is forwarded, not the resolved _is_leaf,
                    # presumably so the nested call resolves it the same way.
                    is_leaf=is_leaf,
                ):
                    if not isinstance(subkey, tuple):
                        subkey = (subkey,)
                    yield (key, *subkey), subitem
                # ... then the composite node itself (post-order).
                if not self.leaves_only and not _is_leaf(type(item)):
                    yield (key, item)
            elif not self.leaves_only or _is_leaf(type(item)):
                yield key, item

        for key, item in self._get_composite_items(is_leaf):
            if is_leaf is _NESTED_TENSORS_AS_LISTS and isinstance(
                item, _LazyStackedMixin
            ):
                # Expose each stacked spec under its own index sub-key.
                for i, spec in enumerate(item._specs):
                    yield from _iter_from_item(unravel_key((key, str(i))), spec)
            else:
                yield from _iter_from_item(key, item)

    def _get_composite_items(self, is_leaf):
        """Yield the top-level ``(key, spec)`` pairs of ``self.composite``."""
        if isinstance(self.composite, LazyStackedCompositeSpec):
            from tensordict.base import _NESTED_TENSORS_AS_LISTS

            if is_leaf is _NESTED_TENSORS_AS_LISTS:
                # One entry per (stack index, key) of every stacked spec.
                for i, spec in enumerate(self.composite._specs):
                    for key, item in spec.items():
                        yield ((str(i), key), item)
            else:
                # Only keys present in every stacked spec, sorted by str so the
                # order does not depend on set iteration order.
                keys = set(self.composite._specs[0].keys())
                for spec in self.composite._specs[1:]:
                    keys.intersection_update(spec.keys())
                yield from ((key, self.composite[key]) for key in sorted(keys, key=str))
        else:
            yield from self.composite._specs.items()

    def __len__(self):
        # Counts by exhausting the iterator, so cost grows with the number of
        # yielded entries.
        return sum(1 for _ in self)

    def __repr__(self):
        return f"{type(self).__name__}(keys={list(self)})"

    def __contains__(self, item):
        # NOTE(review): ``item`` is compared against whatever ``__iter__``
        # yields. On this class that is ``(key, spec)`` pairs, so a plain key
        # never matches here; presumably this is meant for a subclass whose
        # ``__iter__`` yields keys only — confirm.
        item = unravel_key(item)
        if len(item) == 1:
            # Unwrap a length-1 key so it compares equal to its bare element.
            item = item[0]
        return any(key == item for key in self)

    def _keys(self):
        return _CompositeSpecKeysView(self)

    def _values(self):
        return _CompositeSpecValuesView(self)
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
I have a question about the order of keys returned by `CompositeSpec.keys(include_nested=True, leaves_only=False)`. Is there any guarantee on the order of the keys — specifically, that the iterator goes from the deepest nodes up to the root? I want to write a function that removes all empty entries from a `CompositeSpec` (the function's code did not survive in this copy of the discussion). That function assumes that the iterator processes keys from the deepest nodes to the root, and this assumption is crucial for it to work correctly.
If there is no such guarantee, how can I achieve this goal? A `CompositeSpec` with empty entries is quite undesirable. Thank you!
Beta Was this translation helpful? Give feedback.
All reactions