Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 17, 2025
1 parent 256a700 commit dc4f6e5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
48 changes: 40 additions & 8 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,7 @@ def __init__(
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
device: Optional[DEVICE_TYPING] = None,
dtype: torch.dtype | None = None,
example_data: Any = None,
**kwargs,
):
if isinstance(shape, int):
Expand All @@ -2467,6 +2468,7 @@ def __init__(
super().__init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
)
self.example_data = example_data

def cardinality(self) -> Any:
    """Number of enumerable elements of this spec.

    A non-tensor spec wraps an arbitrary Python payload, so its elements
    cannot be enumerated.

    Raises:
        RuntimeError: always, since enumeration is undefined for this spec.
    """
    msg = "Cannot enumerate a NonTensorSpec."
    raise RuntimeError(msg)
Expand All @@ -2485,30 +2487,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
return self.__class__(shape=self.shape, device=dest_device, dtype=None)
return self.__class__(
shape=self.shape,
device=dest_device,
dtype=None,
example_data=self.example_data,
)

def clone(self) -> NonTensor:
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
return self.__class__(
shape=self.shape,
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
)

def rand(self, shape=None):
if shape is None:
shape = ()
return NonTensorData(
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
data=self.example_data,
batch_size=(*shape, *self._safe_shape),
device=self.device,
)

def zero(self, shape=None):
if shape is None:
shape = ()
return NonTensorData(
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
data=self.example_data,
batch_size=(*shape, *self._safe_shape),
device=self.device,
)

def one(self, shape=None):
if shape is None:
shape = ()
return NonTensorData(
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
data=self.example_data,
batch_size=(*shape, *self._safe_shape),
device=self.device,
)

def is_in(self, val: Any) -> bool:
Expand All @@ -2533,23 +2551,36 @@ def expand(self, *shape):
raise ValueError(
f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
)
return self.__class__(shape=shape, device=self.device, dtype=None)
return self.__class__(
shape=shape, device=self.device, dtype=None, example_data=self.example_data
)

def _reshape(self, shape):
return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
return self.__class__(
shape=shape,
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
)

def _unflatten(self, dim, sizes):
shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
return self.__class__(
shape=shape,
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
)

def __getitem__(self, idx: SHAPE_INDEX_TYPING):
"""Indexes the current TensorSpec based on the provided index."""
indexed_shape = _size(_shape_indexing(self.shape, idx))
return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
return self.__class__(
shape=indexed_shape,
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
)

def unbind(self, dim: int = 0):
orig_dim = dim
Expand All @@ -2565,6 +2596,7 @@ def unbind(self, dim: int = 0):
shape=shape,
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
)
for i in range(self.shape[dim])
)
Expand Down
11 changes: 11 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec:
input_spec = self.__dict__.get("_input_spec", None)
return input_spec

def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
    """Sample a random action, deferring to the base env's override when present.

    If the wrapped environment defines its own ``rand_action`` (its attribute
    is not the plain ``EnvBase.rand_action``), that implementation is called
    so env-specific sampling is preserved; otherwise the default
    ``super().rand_action`` path is used.
    """
    if self.base_env.rand_action is not EnvBase.rand_action:
        # TODO: this will fail if the transform modifies the input.
        # For instance, if PendulumEnv overrides rand_action and we build a
        # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
        # env.rand_action will NOT have a discrete action!
        # Getting a discrete action would require coding the inverse transform of an action within
        # ActionDiscretizer (ie, float->int, not int->float).
        return self.base_env.rand_action(tensordict)
    return super().rand_action(tensordict)

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# No need to clone here because inv does it already
# tensordict = tensordict.clone(False)
Expand Down

0 comments on commit dc4f6e5

Please sign in to comment.