From 7a44821ef91a3ac51f963ac9da70c4dd34be809d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 8 Nov 2024 14:40:57 +0000
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 test/test_specs.py           |  5 +++++
 torchrl/data/tensor_specs.py | 22 +++++++++++-----------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/test/test_specs.py b/test/test_specs.py
index 1a7dd41621e..39b09798ac2 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -3823,6 +3823,7 @@ def test_discrete(self):
             spec.enumerate()
             == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
         ).all()
+        assert spec.is_in(spec.enumerate())

     def test_one_hot(self):
         spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
@@ -3839,15 +3840,18 @@ def test_one_hot(self):
                 dtype=torch.bool,
             )
         ).all()
+        assert spec.is_in(spec.enumerate())

     def test_multi_discrete(self):
         spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
         enum = spec.enumerate()
+        assert spec.is_in(enum)
         assert enum.shape == torch.Size([60, 2, 3])

     def test_multi_onehot(self):
         spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
         enum = spec.enumerate()
+        assert spec.is_in(enum)
         assert enum.shape == torch.Size([60, 2, 12])

     def test_composite(self):
@@ -3859,6 +3863,7 @@ def test_composite(self):
             shape=[3],
         )
         c_enum = c.enumerate()
+        assert c.is_in(c_enum)
         assert c_enum.shape == torch.Size((20, 3))
         assert c_enum["b"].shape == torch.Size((20, 3))

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index b641b808cf3..3590d76d2ce 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -835,7 +835,7 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         return self.is_in(item)

     @abc.abstractmethod
-    def enumerate(self):
+    def enumerate(self) -> Any:
         """Returns all the samples that can be obtained from the TensorSpec.

         The samples will be stacked along the first dimension.
@@ -1281,7 +1281,7 @@ def __eq__(self, other):
                 return False
         return True

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor | TensorDictBase:
         return torch.stack(
             [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
         )
@@ -1747,7 +1747,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         return (
             torch.eye(self.n, dtype=self.dtype, device=self.device)
             .expand(*self.shape, self.n)
@@ -2078,7 +2078,7 @@ def __init__(
             domain=domain,
         )

-    def enumerate(self):
+    def enumerate(self) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
         )
@@ -2402,7 +2402,7 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )

-    def enumerate(self):
+    def enumerate(self) -> Any:
         raise NotImplementedError("Cannot enumerate a NonTensorSpec.")

     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
@@ -2641,7 +2641,7 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)

-    def enumerate(self):
+    def enumerate(self) -> Any:
         raise NotImplementedError("enumerate cannot be called with continuous specs.")

     def expand(self, *shape):
@@ -2808,7 +2808,7 @@ def __init__(
         )
         self.update_mask(mask)

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         nvec = self.nvec
         enum_disc = self.to_categorical_spec().enumerate()
         enums = torch.cat(
@@ -3253,7 +3253,7 @@ def __init__(
         )
         self.update_mask(mask)

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
         if self.ndim:
             arange = arange.view(-1, *(1,) * self.ndim)
@@ -3766,7 +3766,7 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton

-    def enumerate(self):
+    def enumerate(self) -> torch.Tensor:
         if self.mask is not None:
             raise RuntimeError(
                 "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
@@ -4682,7 +4682,7 @@ def clone(self) -> Composite:
             shape=self.shape,
         )

-    def enumerate(self):
+    def enumerate(self) -> TensorDictBase:
         # We are going to use meshgrid to create samples of all the subspecs in here
         # but first let's get rid of the batch size, we'll put it back later
         self_without_batch = self
@@ -4959,7 +4959,7 @@ def update(self, dict) -> None:
             self[key] = item
         return self

-    def enumerate(self):
+    def enumerate(self) -> TensorDictBase:
         dim = self.stack_dim
         return LazyStackedTensorDict.maybe_dense_stack(
             [spec.enumerate() for spec in self._specs], dim + 1
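A minimal sketch of the round-trip invariant the new assertions pin down, mirroring the test_multi_discrete case above; the import path torchrl.data is an assumption based on how these spec classes are normally exposed:

    import torch
    from torchrl.data import MultiDiscreteTensorSpec

    spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))

    # enumerate() stacks all 3 * 4 * 5 = 60 attainable samples along a new
    # leading dim, so each sample keeps the spec's (2, 3) shape.
    enum = spec.enumerate()
    assert enum.shape == torch.Size([60, 2, 3])

    # The invariant exercised by the added tests: every sample produced by
    # enumerate() must be contained in the spec that produced it.
    assert spec.is_in(enum)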