From 8eac84ad24bd340f79c72a0c618db67040fffdc8 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Tue, 29 Oct 2024 00:58:54 -0700 Subject: [PATCH 1/9] [BugFix] Rename RayCollector example file to avoid ImportError (#2525) --- .../distributed/collectors/multi_nodes/{ray.py => ray_collect.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/distributed/collectors/multi_nodes/{ray.py => ray_collect.py} (100%) diff --git a/examples/distributed/collectors/multi_nodes/ray.py b/examples/distributed/collectors/multi_nodes/ray_collect.py similarity index 100% rename from examples/distributed/collectors/multi_nodes/ray.py rename to examples/distributed/collectors/multi_nodes/ray_collect.py From 3e4b2928e3f4910a6b87ad928746aa296276cb65 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:43:51 +0100 Subject: [PATCH 2/9] [Doc] MADDPG bug fix of buffer device and improve explaination (#2519) --- torchrl/data/replay_buffers/storages.py | 2 +- .../sphinx-tutorials/multiagent_competitive_ddpg.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index beab68971b5..21cbfce7b31 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1106,7 +1106,7 @@ def __init__( if self.device.type != "cpu": raise ValueError( "Memory map device other than CPU isn't supported. To cast your data to the desired device, " - "use `buffer.append_transform(lambda x: x.to(device)` or a similar transform." + "use `buffer.append_transform(lambda x: x.to(device))` or a similar transform." ) self._len = 0 diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index 08b6d83bf5c..0d0c6360958 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -655,16 +655,23 @@ # There are many types of buffers, in this tutorial we use a basic buffer to store and sample tensordict # data randomly. # +# This buffer uses :class:`~.data.LazyMemmapStorage`, which stores data on disk. +# This allows to use the disk memory, but can result in slower sampling as it requires data to be cast to the training device. +# To store your buffer on the GPU, you can use :class:`~.data.LazyTensorStorage`, passing the desired device. +# This will result in faster sampling but is subject to the memory constraints of the selected device. +# replay_buffers = {} for group, _agents in env.group_map.items(): replay_buffer = ReplayBuffer( storage=LazyMemmapStorage( - memory_size, device=device + memory_size ), # We will store up to memory_size multi-agent transitions sampler=RandomSampler(), batch_size=train_batch_size, # We will sample batches of this size ) + if device.type != "cpu": + replay_buffer.append_transform(lambda x: x.to(device)) replay_buffers[group] = replay_buffer ###################################################################### From da0bf1897e0725054418617c425bf2b7b49547de Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 29 Oct 2024 11:05:58 +0000 Subject: [PATCH 3/9] [Minor] Fix fbcode imports of mocking classes ghstack-source-id: 74f9f3bedf8f48988a1956084548f6cd2f720934 Pull Request resolved: https://github.com/pytorch/rl/pull/2526 --- test/test_actors.py | 3 +- test/test_collector.py | 59 ++++++++++++++++++-------- test/test_cost.py | 4 +- test/test_distributed.py | 6 ++- test/test_env.py | 75 ++++++++++++++++++++++------------ test/test_exploration.py | 15 ++++--- test/test_helpers.py | 38 ++++++++++------- test/test_modules.py | 3 +- test/test_rb.py | 3 +- test/test_tensordictmodules.py | 7 +++- test/test_transforms.py | 42 ++++++++++++------- 11 files changed, 171 insertions(+), 84 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index ac69001db45..c50bf7b9e62 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -8,7 +8,6 @@ import pytest import torch -from mocking_classes import NestedCountingEnv from tensordict import TensorDict from tensordict.nn import CompositeDistribution, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor @@ -33,8 +32,10 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices + from pytorch.rl.test.mocking_classes import NestedCountingEnv else: from _utils_internal import get_default_devices + from mocking_classes import NestedCountingEnv @pytest.mark.parametrize( diff --git a/test/test_collector.py b/test/test_collector.py index deb06d0b5cc..1309254ce2d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -14,23 +14,6 @@ import numpy as np import pytest import torch -from mocking_classes import ( - ContinuousActionVecMockEnv, - CountingBatchedEnv, - CountingEnv, - CountingEnvCountPolicy, - DiscreteActionConvMockEnv, - DiscreteActionConvPolicy, - DiscreteActionVecMockEnv, - DiscreteActionVecPolicy, - EnvWithDynamicSpec, - HeterogeneousCountingEnv, - HeterogeneousCountingEnvPolicy, - MockSerialEnv, - MultiKeyCountingEnv, - MultiKeyCountingEnvPolicy, - NestedCountingEnv, -) from packaging import version from tensordict import ( assert_allclose_td, @@ -103,6 +86,23 @@ PONG_VERSIONED, retry, ) + from pytorch.rl.test.mocking_classes import ( + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + DiscreteActionConvMockEnv, + DiscreteActionConvPolicy, + DiscreteActionVecMockEnv, + DiscreteActionVecPolicy, + EnvWithDynamicSpec, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, + MockSerialEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) else: from _utils_internal import ( CARTPOLE_VERSIONED, @@ -116,6 +116,23 @@ PONG_VERSIONED, retry, ) + from mocking_classes import ( + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + DiscreteActionConvMockEnv, + DiscreteActionConvPolicy, + DiscreteActionVecMockEnv, + DiscreteActionVecPolicy, + EnvWithDynamicSpec, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, + MockSerialEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) # torch.set_default_dtype(torch.double) IS_WINDOWS = sys.platform == "win32" @@ -1953,7 +1970,13 @@ def test_collector_nested_env_combinations( @pytest.mark.parametrize("batch_size", [(), (5,), (5, 2)]) def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): - from mocking_classes import CountingEnvCountPolicy, NestedCountingEnv + if os.getenv("PYTORCH_TEST_FBCODE"): + from pytorch.rl.test.mocking_classes import ( + CountingEnvCountPolicy, + NestedCountingEnv, + ) + else: + from mocking_classes import CountingEnvCountPolicy, NestedCountingEnv env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim) env_fn = lambda: NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim) diff --git a/test/test_cost.py b/test/test_cost.py index 85d1b6f7dc9..0066c024776 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -18,8 +18,6 @@ import pytest import torch -from mocking_classes import ContinuousActionConvMockEnv - from packaging import version, version as pack_version from tensordict import assert_allclose_td, TensorDict, TensorDictBase @@ -141,12 +139,14 @@ get_available_devices, get_default_devices, ) + from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv else: from _utils_internal import ( # noqa dtype_fixture, get_available_devices, get_default_devices, ) + from mocking_classes import ContinuousActionConvMockEnv _has_functorch = True try: diff --git a/test/test_distributed.py b/test/test_distributed.py index fd369f64962..2529e4e3d7e 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -27,7 +27,6 @@ import torch -from mocking_classes import ContinuousActionVecMockEnv, CountingEnv from torch import multiprocessing as mp, nn from torchrl.collectors.collectors import ( @@ -44,6 +43,11 @@ from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG from torchrl.envs.utils import RandomPolicy +if os.getenv("PYTORCH_TEST_FBCODE"): + from pytorch.rl.test.mocking_classes import ContinuousActionVecMockEnv, CountingEnv +else: + from mocking_classes import ContinuousActionVecMockEnv, CountingEnv + TIMEOUT = 200 if sys.platform.startswith("win"): diff --git a/test/test_env.py b/test/test_env.py index 046fd64ca19..04bf18c7c8c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -30,6 +30,31 @@ PONG_VERSIONED, rand_reset, ) + from pytorch.rl.test.mocking_classes import ( + ActionObsMergeLinear, + AutoResetHeteroCountingEnv, + AutoResettingCountingEnv, + ContinuousActionConvMockEnv, + ContinuousActionConvMockEnvNumpy, + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + DiscreteActionConvMockEnv, + DiscreteActionConvMockEnvNumpy, + DiscreteActionVecMockEnv, + DummyModelBasedEnvBase, + EnvWithDynamicSpec, + EnvWithMetadata, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, + MockBatchedLockedEnv, + MockBatchedUnLockedEnv, + MockSerialEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) else: from _utils_internal import ( _make_envs, @@ -42,31 +67,31 @@ PONG_VERSIONED, rand_reset, ) -from mocking_classes import ( - ActionObsMergeLinear, - AutoResetHeteroCountingEnv, - AutoResettingCountingEnv, - ContinuousActionConvMockEnv, - ContinuousActionConvMockEnvNumpy, - ContinuousActionVecMockEnv, - CountingBatchedEnv, - CountingEnv, - CountingEnvCountPolicy, - DiscreteActionConvMockEnv, - DiscreteActionConvMockEnvNumpy, - DiscreteActionVecMockEnv, - DummyModelBasedEnvBase, - EnvWithDynamicSpec, - EnvWithMetadata, - HeterogeneousCountingEnv, - HeterogeneousCountingEnvPolicy, - MockBatchedLockedEnv, - MockBatchedUnLockedEnv, - MockSerialEnv, - MultiKeyCountingEnv, - MultiKeyCountingEnvPolicy, - NestedCountingEnv, -) + from mocking_classes import ( + ActionObsMergeLinear, + AutoResetHeteroCountingEnv, + AutoResettingCountingEnv, + ContinuousActionConvMockEnv, + ContinuousActionConvMockEnvNumpy, + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + DiscreteActionConvMockEnv, + DiscreteActionConvMockEnvNumpy, + DiscreteActionVecMockEnv, + DummyModelBasedEnvBase, + EnvWithDynamicSpec, + EnvWithMetadata, + HeterogeneousCountingEnv, + HeterogeneousCountingEnvPolicy, + MockBatchedLockedEnv, + MockBatchedUnLockedEnv, + MockSerialEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) from packaging import version from tensordict import ( assert_allclose_td, diff --git a/test/test_exploration.py b/test/test_exploration.py index b7d0bfd2ffa..b69b743e48c 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -8,11 +8,6 @@ import pytest import torch -from mocking_classes import ( - ContinuousActionVecMockEnv, - CountingEnvCountModule, - NestedCountingEnv, -) from scipy.stats import ttest_1samp from tensordict import TensorDict @@ -49,8 +44,18 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices + from pytorch.rl.test.mocking_classes import ( + ContinuousActionVecMockEnv, + CountingEnvCountModule, + NestedCountingEnv, + ) else: from _utils_internal import get_default_devices + from mocking_classes import ( + ContinuousActionVecMockEnv, + CountingEnvCountModule, + NestedCountingEnv, + ) class TestEGreedy: diff --git a/test/test_helpers.py b/test/test_helpers.py index 46edc93ee2f..cf1160f1bb2 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -12,23 +12,9 @@ import pytest import torch -from torchrl._utils import timeit - -try: - from hydra import compose, initialize - from hydra.core.config_store import ConfigStore - _has_hydra = True -except ImportError: - _has_hydra = False -from mocking_classes import ( - ContinuousActionConvMockEnvNumpy, - ContinuousActionVecMockEnv, - DiscreteActionConvMockEnvNumpy, - DiscreteActionVecMockEnv, - MockSerialEnv, -) from packaging import version +from torchrl._utils import timeit from torchrl.data import Bounded, Composite from torchrl.envs.libs.gym import _has_gym from torchrl.envs.transforms import ObservationNorm @@ -51,8 +37,30 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import generate_seeds, get_default_devices + from pytorch.rl.test.mocking_classes import ( + ContinuousActionConvMockEnvNumpy, + ContinuousActionVecMockEnv, + DiscreteActionConvMockEnvNumpy, + DiscreteActionVecMockEnv, + MockSerialEnv, + ) else: from _utils_internal import generate_seeds, get_default_devices + from mocking_classes import ( + ContinuousActionConvMockEnvNumpy, + ContinuousActionVecMockEnv, + DiscreteActionConvMockEnvNumpy, + DiscreteActionVecMockEnv, + MockSerialEnv, + ) + +try: + from hydra import compose, initialize + from hydra.core.config_store import ConfigStore + + _has_hydra = True +except ImportError: + _has_hydra = False TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) if TORCH_VERSION < version.parse("1.12.0"): diff --git a/test/test_modules.py b/test/test_modules.py index ddc1e9315e6..cfce54fc1c2 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -14,9 +14,10 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices, retry + from pytorch.rl.test.mocking_classes import MockBatchedUnLockedEnv else: from _utils_internal import get_default_devices, retry -from mocking_classes import MockBatchedUnLockedEnv + from mocking_classes import MockBatchedUnLockedEnv from packaging import version from tensordict import TensorDict from torch import nn diff --git a/test/test_rb.py b/test/test_rb.py index 1708f4279ab..c14ccb64c04 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -24,6 +24,7 @@ get_default_devices, make_tc, ) + from pytorch.rl.test.mocking_classes import CountingEnv else: from _utils_internal import ( capture_log_records, @@ -31,8 +32,8 @@ get_default_devices, make_tc, ) + from mocking_classes import CountingEnv -from mocking_classes import CountingEnv from packaging import version from packaging.version import parse from tensordict import ( diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ea177cb9f96..05e46ce0ecd 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -4,10 +4,10 @@ # LICENSE file in the root directory of this source tree. import argparse +import os import pytest import torch -from mocking_classes import CountingEnv, DiscreteActionVecMockEnv from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from torch import nn @@ -52,6 +52,11 @@ from torchrl.modules.utils import get_primers_from_module from torchrl.objectives import DDPGLoss +if os.getenv("PYTORCH_TEST_FBCODE"): + from pytorch.rl.test.mocking_classes import CountingEnv, DiscreteActionVecMockEnv +else: + from mocking_classes import CountingEnv, DiscreteActionVecMockEnv + _has_functorch = False try: try: diff --git a/test/test_transforms.py b/test/test_transforms.py index 84c4b3871fa..b4465aec483 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -32,6 +32,20 @@ rand_reset, retry, ) + from pytorch.rl.test.mocking_classes import ( + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + DiscreteActionConvMockEnv, + DiscreteActionConvMockEnvNumpy, + IncrementingEnv, + MockBatchedLockedEnv, + MockBatchedUnLockedEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) else: from _utils_internal import ( # noqa BREAKOUT_VERSIONED, @@ -43,20 +57,20 @@ rand_reset, retry, ) -from mocking_classes import ( - ContinuousActionVecMockEnv, - CountingBatchedEnv, - CountingEnv, - CountingEnvCountPolicy, - DiscreteActionConvMockEnv, - DiscreteActionConvMockEnvNumpy, - IncrementingEnv, - MockBatchedLockedEnv, - MockBatchedUnLockedEnv, - MultiKeyCountingEnv, - MultiKeyCountingEnvPolicy, - NestedCountingEnv, -) + from mocking_classes import ( + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + DiscreteActionConvMockEnv, + DiscreteActionConvMockEnvNumpy, + IncrementingEnv, + MockBatchedLockedEnv, + MockBatchedUnLockedEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) from tensordict import TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictSequential from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td From d524d0d6b34f6d1a912dfc0ad0beb4fa1e18993a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 29 Oct 2024 11:47:36 +0000 Subject: [PATCH 4/9] [Feature] Send info dict to the storage device in RBs ghstack-source-id: 4ed60d649b17f96b49f90d234e679937c60a3c32 Pull Request resolved: https://github.com/pytorch/rl/pull/2527 --- torchrl/data/replay_buffers/replay_buffers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index aa88dd9d186..2672c90092f 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -712,7 +712,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An "for a proper usage of the batch-size arguments." ) if not self._prefetch: - ret = self._sample(batch_size) + result = self._sample(batch_size) else: with self._futures_lock: while ( @@ -722,11 +722,15 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An ) or not len(self._prefetch_queue): fut = self._prefetch_executor.submit(self._sample, batch_size) self._prefetch_queue.append(fut) - ret = self._prefetch_queue.popleft().result() + result = self._prefetch_queue.popleft().result() if return_info: - return ret - return ret[0] + out, info = result + if getattr(self.storage, "device", None) is not None: + device = self.storage.device + info = tree_map(lambda x: x.to(device) if hasattr(x, "to") else x, info) + return out, info + return result[0] def mark_update(self, index: Union[int, torch.Tensor]) -> None: self._sampler.mark_update(index, storage=self._storage) From c851e1698dd4a46fd1429f5afe8057c80072475e Mon Sep 17 00:00:00 2001 From: Faury Louis Date: Tue, 29 Oct 2024 13:41:13 +0100 Subject: [PATCH 5/9] [Feature] Adds ordinal distributions (#2520) Co-authored-by: Louis Faury --- docs/source/reference/modules.rst | 2 + test/test_distributions.py | 122 ++++++++++++++++++++++ torchrl/modules/__init__.py | 2 + torchrl/modules/distributions/__init__.py | 4 + torchrl/modules/distributions/discrete.py | 95 +++++++++++++++-- 5 files changed, 215 insertions(+), 10 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index e1642868228..349d1277c98 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -553,6 +553,8 @@ Some distributions are typically used in RL scripts. OneHotCategorical MaskedCategorical MaskedOneHotCategorical + Ordinal + OneHotOrdinal Utils ----- diff --git a/test/test_distributions.py b/test/test_distributions.py index 79929135bc8..e283fb9a9b8 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -17,6 +17,8 @@ from torchrl.modules import ( NormalParamWrapper, OneHotCategorical, + OneHotOrdinal, + Ordinal, ReparamGradientStrategy, TanhNormal, TruncatedNormal, @@ -28,6 +30,7 @@ TanhDelta, ) from torchrl.modules.distributions.continuous import SafeTanhTransform +from torchrl.modules.distributions.discrete import _generate_ordinal_logits if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices @@ -677,6 +680,125 @@ def test_reparam(self, grad_method, sparse): assert logits.grad is not None and logits.grad.norm() > 0 +class TestOrdinal: + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)]) + def test_correct_sampling_shape( + self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str + ) -> None: + logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device) + + sampler = Ordinal(scores=logits) + actions = sampler.sample() # type: ignore[no-untyped-call] + log_probs = sampler.log_prob(actions) # type: ignore[no-untyped-call] + + expected_log_prob_shape = logit_shape[:-1] + expected_action_shape = logit_shape[:-1] + + assert actions.size() == torch.Size(expected_action_shape) + assert log_probs.size() == torch.Size(expected_log_prob_shape) + + @pytest.mark.parametrize("num_categories", [1, 10, 20]) + def test_correct_range(self, num_categories: int) -> None: + seq_size = 10 + batch_size = 100 + logits = torch.ones((batch_size, seq_size, num_categories)) + + sampler = Ordinal(scores=logits) + + actions = sampler.sample() # type: ignore[no-untyped-call] + + assert actions.min() >= 0 + assert actions.max() < num_categories + + def test_bounded_gradients(self) -> None: + logits = torch.tensor( + [[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]], + requires_grad=True, + dtype=torch.float32, + ) + + sampler = Ordinal(scores=logits) + + actions = sampler.sample() + log_probs = sampler.log_prob(actions) + + dummy_objective = log_probs.sum() + dummy_objective.backward() + + assert logits.grad is not None + assert not torch.isnan(logits.grad).any() + + def test_generate_ordinal_logits_numerical(self) -> None: + logits = torch.ones((3, 4)) + + ordinal_logits = _generate_ordinal_logits(scores=logits) + + expected_ordinal_logits = torch.tensor( + [ + [-4.2530, -3.2530, -2.2530, -1.2530], + [-4.2530, -3.2530, -2.2530, -1.2530], + [-4.2530, -3.2530, -2.2530, -1.2530], + ] + ) + + torch.testing.assert_close( + ordinal_logits, expected_ordinal_logits, atol=1e-4, rtol=1e-6 + ) + + +class TestOneHotOrdinal: + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)]) + def test_correct_sampling_shape( + self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str + ) -> None: + logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device) + + sampler = OneHotOrdinal(scores=logits) + actions = sampler.sample() # type: ignore[no-untyped-call] + log_probs = sampler.log_prob(actions) # type: ignore[no-untyped-call] + expected_log_prob_shape = logit_shape[:-1] + + expected_action_shape = logit_shape + + assert actions.size() == torch.Size(expected_action_shape) + assert log_probs.size() == torch.Size(expected_log_prob_shape) + + @pytest.mark.parametrize("num_categories", [2, 10, 20]) + def test_correct_range(self, num_categories: int) -> None: + seq_size = 10 + batch_size = 100 + logits = torch.ones((batch_size, seq_size, num_categories)) + + sampler = OneHotOrdinal(scores=logits) + + actions = sampler.sample() # type: ignore[no-untyped-call] + + assert torch.all(actions.sum(-1)) + assert actions.shape[-1] == num_categories + + def test_bounded_gradients(self) -> None: + logits = torch.tensor( + [[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]], + requires_grad=True, + dtype=torch.float32, + ) + + sampler = OneHotOrdinal(scores=logits) + + actions = sampler.sample() + log_probs = sampler.log_prob(actions) + + dummy_objective = log_probs.sum() + dummy_objective.backward() + + assert logits.grad is not None + assert not torch.isnan(logits.grad).any() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index f65461842bb..8523a783676 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -14,6 +14,8 @@ NormalParamExtractor, NormalParamWrapper, OneHotCategorical, + OneHotOrdinal, + Ordinal, ReparamGradientStrategy, TanhDelta, TanhNormal, diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 367765812bb..52f8f302a35 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -17,6 +17,8 @@ MaskedCategorical, MaskedOneHotCategorical, OneHotCategorical, + OneHotOrdinal, + Ordinal, ReparamGradientStrategy, ) @@ -31,5 +33,7 @@ MaskedCategorical, MaskedOneHotCategorical, OneHotCategorical, + Ordinal, + OneHotOrdinal, ) } diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index d2ffba30686..eb802294a12 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -9,11 +9,9 @@ import torch import torch.distributions as D +import torch.nn.functional as F -__all__ = [ - "OneHotCategorical", - "MaskedCategorical", -] +__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"] def _treat_categorical_params( @@ -56,7 +54,7 @@ class ReparamGradientStrategy(Enum): class OneHotCategorical(D.Categorical): """One-hot categorical distribution. - This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings + This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings of the discrete tensors. Args: @@ -66,7 +64,7 @@ class OneHotCategorical(D.Categorical): reparameterized samples. ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients by using the softmax valued log-probability as a proxy to the - samples gradients. + sample gradients. ``ReparamGradientStrategy.RelaxedOneHot`` will use :class:`torch.distributions.RelaxedOneHot` to sample from the distribution. @@ -81,8 +79,6 @@ class OneHotCategorical(D.Categorical): """ - num_params: int = 1 - def __init__( self, logits: Optional[torch.Tensor] = None, @@ -155,7 +151,7 @@ class MaskedCategorical(D.Categorical): Args: logits (torch.Tensor): event log probabilities (unnormalized) probs (torch.Tensor): event probabilities. If provided, the probabilities - corresponding to to masked items will be zeroed and the probability + corresponding to masked items will be zeroed and the probability re-normalized along its last dimension. Keyword Args: @@ -306,7 +302,7 @@ class MaskedOneHotCategorical(MaskedCategorical): Args: logits (torch.Tensor): event log probabilities (unnormalized) probs (torch.Tensor): event probabilities. If provided, the probabilities - corresponding to to masked items will be zeroed and the probability + corresponding to masked items will be zeroed and the probability re-normalized along its last dimension. Keyword Args: @@ -469,3 +465,82 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten raise ValueError( f"Unknown reparametrization strategy {self.reparam_strategy}." ) + + +class Ordinal(D.Categorical): + """A discrete distribution for learning to sample from finite ordered sets. + + It is defined in contrast with the `Categorical` distribution, which does + not impose any notion of proximity or ordering over its support's atoms. + The `Ordinal` distribution explicitly encodes those concepts, which is + useful for learning discrete sampling from continuous sets. See ยง5 of + `Tang & Agrawal, 2020`_ for details. + + .. note:: + This class is mostly useful when you want to learn a distribution over + a finite set which is obtained by discretising a continuous set. + + Args: + scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions. + Typically, the output of a neural network parametrising the distribution. + + Examples: + >>> num_atoms, num_samples = 5, 20 + >>> mean = (num_atoms - 1) / 2 # Target mean for samples, centered around the middle atom + >>> torch.manual_seed(42) + >>> logits = torch.ones((num_atoms), requires_grad=True) + >>> optimizer = torch.optim.Adam([logits], lr=0.1) + >>> + >>> # Perform optimisation loop to minimise deviation from `mean` + >>> for _ in range(20): + >>> sampler = Ordinal(scores=logits) + >>> samples = sampler.sample((num_samples,)) + >>> # Define loss to encourage samples around the mean by penalising deviation from mean + >>> loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples)) + >>> loss.backward() + >>> optimizer.step() + >>> optimizer.zero_grad() + >>> + >>> sampler.probs + tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...) + >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4) + >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms) + torch.return_types.histogram( + hist=tensor([ 24., 158., 478., 228., 112.]), + bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000])) + """ + + def __init__(self, scores: torch.Tensor): + logits = _generate_ordinal_logits(scores) + super().__init__(logits=logits) + + +class OneHotOrdinal(OneHotCategorical): + """The one-hot version of the :class:`~tensordict.nn.distributions.Ordinal` distribution. + + Args: + scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions. + Typically, the output of a neural network parametrising the distribution. + """ + + def __init__(self, scores: torch.Tensor): + logits = _generate_ordinal_logits(scores) + super().__init__(logits=logits) + + +def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor: + """Implements Eq. 4 of `Tang & Agrawal, 2020`__.""" + # Assigns Bernoulli-like probabilities for each class in the set + log_probs = F.logsigmoid(scores) + complementary_log_probs = F.logsigmoid(-scores) + + # Total log-probability for being "larger than k" + larger_than_log_probs = log_probs.cumsum(dim=-1) + + # Total log-probability for being "smaller than k" + smaller_than_log_probs = ( + complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + - complementary_log_probs + ) + + return larger_than_log_probs + smaller_than_log_probs From 6799a7f5d111bcab355c44831d5996b2e5517d06 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 25 Oct 2024 17:47:17 -0700 Subject: [PATCH 6/9] [BugFix] Fix pendulum device ghstack-source-id: bcaf20de6e317d4bda0e1511e0b1e46653a6f352 Pull Request resolved: https://github.com/pytorch/rl/pull/2516 --- test/test_env.py | 12 ++++++------ torchrl/envs/custom/pendulum.py | 9 +++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 04bf18c7c8c..1f95a55c2c7 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3405,16 +3405,16 @@ def test_tictactoe_env_single(self): ) assert r.shape == (5, 100) - def test_pendulum_env(self): - env = PendulumEnv(device=None) - assert env.device is None - env = PendulumEnv(device="cpu") - assert env.device == torch.device("cpu") + @pytest.mark.parametrize("device", [None, *get_default_devices()]) + def test_pendulum_env(self, device): + env = PendulumEnv(device=device) + assert env.device == device check_env_specs(env) + for _ in range(10): r = env.rollout(10) assert r.shape == torch.Size((10,)) - r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device)) assert r.shape == torch.Size((5, 10)) diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index e2007227127..579faecc3c6 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -220,7 +220,7 @@ class PendulumEnv(EnvBase): def __init__(self, td_params=None, seed=None, device=None): if td_params is None: - td_params = self.gen_params() + td_params = self.gen_params(device=self.device) super().__init__(device=device) self._make_spec(td_params) @@ -273,7 +273,7 @@ def _reset(self, tensordict): # if no ``tensordict`` is passed, we generate a single set of hyperparameters # Otherwise, we assume that the input ``tensordict`` contains all the relevant # parameters to get started. - tensordict = self.gen_params(batch_size=batch_size) + tensordict = self.gen_params(batch_size=batch_size, device=self.device) high_th = torch.tensor(self.DEFAULT_X, device=self.device) high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device) @@ -355,12 +355,12 @@ def make_composite_from_td(td): return composite def _set_seed(self, seed: int): - rng = torch.Generator() + rng = torch.Generator(device=self.device) rng.manual_seed(seed) self.rng = rng @staticmethod - def gen_params(g=10.0, batch_size=None) -> TensorDictBase: + def gen_params(g=10.0, batch_size=None, device=None) -> TensorDictBase: """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits.""" if batch_size is None: batch_size = [] @@ -379,6 +379,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase: ) }, [], + device=device, ) if batch_size: td = td.expand(batch_size).contiguous() From edbf3dee358b828a9bac1387b46d7ce740f467b1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 31 Oct 2024 10:18:35 +0000 Subject: [PATCH 7/9] [Doc] Fix modules doc (#2531) --- docs/source/reference/modules.rst | 211 +++++------------- .../tensordict_module/probabilistic.py | 2 +- 2 files changed, 53 insertions(+), 160 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 349d1277c98..ee78c68835f 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -92,8 +92,7 @@ Some algorithms such as PPO require a probabilistic policy to be implemented. In TorchRL, these policies take the form of a model, followed by a distribution constructor. - .. note:: - The choice of a probabilistic or regular actor class depends on the algorithm + .. note:: The choice of a probabilistic or regular actor class depends on the algorithm that is being implemented. On-policy algorithms usually require a probabilistic actor, off-policy usually have a deterministic actor with an extra exploration strategy. There are, however, many exceptions to this rule. @@ -103,8 +102,12 @@ and outputs the parameters of a distribution, while the distribution constructor reads these parameters and gets a random sample from the distribution and/or provides a :class:`torch.distributions.Distribution` object. - >>> from tensordict.nn import NormalParamExtractor, TensorDictSequential + >>> from tensordict.nn import NormalParamExtractor, TensorDictSequential, TensorDictModule + >>> from torchrl.modules import SafeProbabilisticModule + >>> from torchrl.envs import GymEnv >>> from torch.distributions import Normal + >>> from torch import nn + >>> >>> env = GymEnv("Pendulum-v1") >>> action_spec = env.action_spec >>> model = nn.Sequential(nn.LazyLinear(action_spec.shape[-1] * 2), NormalParamExtractor()) @@ -125,6 +128,7 @@ provides a :class:`torch.distributions.Distribution` object. To facilitate the construction of probabilistic policies, we provide a dedicated :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`: + >>> from torchrl.modules import ProbabilisticActor >>> policy = ProbabilisticActor( ... model, ... in_keys=["loc", "scale"], @@ -154,69 +158,31 @@ of this action. Q-Value actors ~~~~~~~~~~~~~~ -Q-Value actors are a special type of policy that does not directly predict an action -from an observation, but picks the action that maximised the value (or *quality*) -of a (s,a) -> v map. This map can be a table or a function. -For discrete action spaces with continuous (or near-continuous such as pixels) -states, it is customary to use a non-linear model such as a neural network for -the map. -The semantic of the Q-Value network is hopefully quite simple: we just need to -feed a tensor-to-tensor map that given a certain state (the input tensor), -outputs a list of action values to choose from. The wrapper will write the -resulting action in the input tensordict along with the list of action values. +Q-Value actors are a type of policy that selects actions based on the maximum value +(or "quality") of a state-action pair. This value can be represented as a table or a +function. For discrete action spaces with continuous states, it's common to use a non-linear +model like a neural network to represent this function. - >>> import torch - >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules.tensordict_module.actors import QValueActor - >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) - >>> # we have 4 actions to choose from - >>> action_spec = OneHot(4) - >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available - >>> module = nn.Linear(3, 4) - >>> qvalue_actor = QValueActor(module=module, spec=action_spec) - >>> qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), - chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) +QValueActor +^^^^^^^^^^^ -Distributional Q-learning is slightly different: in this case, the value network -does not output a scalar value for each state-action value. -Instead, the value space is divided in a an arbitrary number of "bins". The -value network outputs a probability that the state-action value belongs to one bin -or another. -Hence, for a state space of dimension M, an action space of dimension N and a number of bins B, -the value network encodes a -of a (s,a) -> v map. This map can be a table or a function. -For discrete action spaces with continuous (or near-continuous such as pixels) -states, it is customary to use a non-linear model such as a neural network for -the map. -The semantic of the Q-Value network is hopefully quite simple: we just need to -feed a tensor-to-tensor map that given a certain state (the input tensor), -outputs a list of action values to choose from. The wrapper will write the -resulting action in the input tensordict along with the list of action values. +The :class:`~torchrl.modules.QValueActor` class takes in a module and an action +specification, and outputs the selected action and its corresponding value. >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueActor + >>> # Create a tensor dict with an observation >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) - >>> # we have 4 actions to choose from + >>> # Define the action space >>> action_spec = OneHot(4) - >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available + >>> # Create a linear module to output action values >>> module = nn.Linear(3, 4) + >>> # Create a QValueActor instance >>> qvalue_actor = QValueActor(module=module, spec=action_spec) + >>> # Run the actor on the tensor dict >>> qvalue_actor(td) >>> print(td) TensorDict( @@ -229,122 +195,48 @@ resulting action in the input tensordict along with the list of action values. device=None, is_shared=False) -Distributional Q-learning is slightly different: in this case, the value network -does not output a scalar value for each state-action value. -Instead, the value space is divided in a an arbitrary number of "bins". The -value network outputs a probability that the state-action value belongs to one bin -or another. -Hence, for a state space of dimension M, an action space of dimension N and a number of bins B, -the value network encodes a -of a (s,a) -> v map. This map can be a table or a function. -For discrete action spaces with continuous (or near-continuous such as pixels) -states, it is customary to use a non-linear model such as a neural network for -the map. -The semantic of the Q-Value network is hopefully quite simple: we just need to -feed a tensor-to-tensor map that given a certain state (the input tensor), -outputs a list of action values to choose from. The wrapper will write the -resulting action in the input tensordict along with the list of action values. +This will output a tensor dict with the selected action and its corresponding value. + +Distributional Q-Learning +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Distributional Q-learning is a variant of Q-learning that represents the value function as a +probability distribution over possible values, rather than a single scalar value. +This allows the agent to learn about the uncertainty in the environment and make more informed +decisions. +In TorchRL, distributional Q-learning is implemented using the :class:`~torchrl.modules.DistributionalQValueActor` +class. This class takes in a module, an action specification, and a support vector, and outputs the selected +action and its corresponding value distribution. + >>> import torch >>> from tensordict import TensorDict - >>> from tensordict.nn.functional_modules import make_functional >>> from torch import nn >>> from torchrl.data import OneHot - >>> from torchrl.modules.tensordict_module.actors import QValueActor - >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5]) - >>> # we have 4 actions to choose from + >>> from torchrl.modules import DistributionalQValueActor, MLP + >>> # Create a tensor dict with an observation + >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) + >>> # Define the action space >>> action_spec = OneHot(4) - >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available - >>> module = nn.Linear(3, 4) - >>> qvalue_actor = QValueActor(module=module, spec=action_spec) - >>> qvalue_actor(td) + >>> # Define the number of bins for the value distribution + >>> nbins = 3 + >>> # Create an MLP module to output logits for the value distribution + >>> module = MLP(out_features=(nbins, 4), depth=2) + >>> # Create a DistributionalQValueActor instance + >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) + >>> # Run the actor on the tensor dict + >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), - chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False) -Distributional Q-learning is slightly different: in this case, the value network -does not output a scalar value for each state-action value. -Instead, the value space is divided in a an arbitrary number of "bins". The -value network outputs a probability that the state-action value belongs to one bin -or another. -Hence, for a state space of dimension M, an action space of dimension N and a number of bins B, -the value network encodes a :math:`\mathbb{R}^{M} \rightarrow \mathbb{R}^{N \times B}` -map. The following example shows how this works in TorchRL with the :class:`~torchrl.modules.tensordict_module.DistributionalQValueActor` -class: - - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules import DistributionalQValueActor, MLP - >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) - >>> nbins = 3 - >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3 - >>> module = MLP(out_features=(nbins, 4), depth=2) - >>> action_spec = OneHot(4) - >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> td = qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules import DistributionalQValueActor, MLP - >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) - >>> nbins = 3 - >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3 - >>> module = MLP(out_features=(nbins, 4), depth=2) - >>> action_spec = OneHot(4) - >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> td = qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - - >>> import torch - >>> from tensordict import TensorDict - >>> from torch import nn - >>> from torchrl.data import OneHot - >>> from torchrl.modules import DistributionalQValueActor, MLP - >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) - >>> nbins = 3 - >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3 - >>> module = MLP(out_features=(nbins, 4), depth=2) - >>> action_spec = OneHot(4) - >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins)) - >>> td = qvalue_actor(td) - >>> print(td) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), - action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), - observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([5]), - device=None, - is_shared=False) - +This will output a tensor dict with the selected action and its corresponding value distribution. .. currentmodule:: torchrl.modules.tensordict_module @@ -403,11 +295,10 @@ without shared parameters. It is mainly intended as a replacement for Domain-specific TensorDict modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. currentmodule:: torchrl.modules.tensordict_module These modules include dedicated solutions for MBRL or RLHF pipelines. -.. currentmodule:: torchrl.modules.tensordict_module - .. autosummary:: :toctree: generated/ :template: rl_template_noinherit.rst @@ -558,9 +449,11 @@ Some distributions are typically used in RL scripts. Utils ----- - .. currentmodule:: torchrl.modules.utils +The module utils include functionals used to do some custom mappings as well as a tool to +build :class:`~torchrl.envs.TensorDictPrimer` instances from a given module. + .. autosummary:: :toctree: generated/ :template: rl_template_noinherit.rst diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 483d9b90eea..8bd5143d20f 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -128,7 +128,7 @@ def __init__( if spec is not None and not isinstance(spec, TensorSpec): raise TypeError("spec must be a TensorSpec subclass") elif spec is not None and not isinstance(spec, Composite): - if len(self.out_keys) > 1: + if len(self.out_keys) - return_log_prob > 1: raise RuntimeError( f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. " "Consider using a Composite object or no spec at all." From 05aeb897512c459c9a7f654ce32ca94c4d8e54b0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 4 Nov 2024 12:52:04 +0000 Subject: [PATCH 8/9] [Feature] TrajCounter transform ghstack-source-id: 62a3091e5c9072f26266143319f30de1729c0d4e Pull Request resolved: https://github.com/pytorch/rl/pull/2532 --- docs/source/reference/envs.rst | 1 + test/test_transforms.py | 210 +++++++++++++++++++++++++ torchrl/envs/__init__.py | 1 + torchrl/envs/batched_envs.py | 2 +- torchrl/envs/env_creator.py | 75 ++++++++- torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 212 +++++++++++++++++++++++++- 7 files changed, 494 insertions(+), 8 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 960daf0fb12..4519900ae8b 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -845,6 +845,7 @@ to be able to create this other composition: TensorDictPrimer TimeMaxPool ToTensorImage + TrajCounter UnsqueezeTransform VC1Transform VIPRewardTransform diff --git a/test/test_transforms.py b/test/test_transforms.py index b4465aec483..59612d2bd65 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -21,6 +21,8 @@ import tensordict.tensordict import torch +from torchrl.collectors import MultiSyncDataCollector + if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa BREAKOUT_VERSIONED, @@ -135,6 +137,7 @@ TensorDictPrimer, TimeMaxPool, ToTensorImage, + TrajCounter, TransformedEnv, UnsqueezeTransform, VC1Transform, @@ -1926,6 +1929,213 @@ def test_stepcounter_ignore(self): assert env.transform.step_count_keys[0] == ("data", "step_count") +class TestTrajCounter(TransformBase): + def test_single_trans_env_check(self): + torch.manual_seed(0) + env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter()) + env.transform.transform_observation_spec(env.base_env.observation_spec) + check_env_specs(env) + + @pytest.mark.parametrize("predefined", [True, False]) + def test_parallel_trans_env_check(self, predefined): + if predefined: + t = TrajCounter() + else: + t = None + + def make_env(max_steps=4, t=t): + if t is None: + t = TrajCounter() + env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) + env.transform.transform_observation_spec(env.base_env.observation_spec) + return env + + if predefined: + penv = ParallelEnv( + 2, + [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)], + mp_start_method="spawn", + ) + else: + make_env_c0 = EnvCreator(make_env) + make_env_c1 = make_env_c0.make_variant(max_steps=5) + penv = ParallelEnv( + 2, + [make_env_c0, make_env_c1], + mp_start_method="spawn", + ) + + r = penv.rollout(100, break_when_any_done=False) + s0 = set(r[0]["traj_count"].squeeze().tolist()) + s1 = set(r[1]["traj_count"].squeeze().tolist()) + assert len(s1.intersection(s0)) == 0 + + @pytest.mark.parametrize("predefined", [True, False]) + def test_serial_trans_env_check(self, predefined): + if predefined: + t = TrajCounter() + else: + t = None + + def make_env(max_steps=4, t=t): + if t is None: + t = TrajCounter() + else: + t = t.clone() + env = TransformedEnv(CountingEnv(max_steps=max_steps), t) + env.transform.transform_observation_spec(env.base_env.observation_spec) + return env + + if predefined: + penv = SerialEnv( + 2, + [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)], + ) + else: + make_env_c0 = EnvCreator(make_env) + make_env_c1 = make_env_c0.make_variant(max_steps=5) + penv = SerialEnv( + 2, + [make_env_c0, make_env_c1], + ) + + r = penv.rollout(100, break_when_any_done=False) + s0 = set(r[0]["traj_count"].squeeze().tolist()) + s1 = set(r[1]["traj_count"].squeeze().tolist()) + assert len(s1.intersection(s0)) == 0 + + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + env = TransformedEnv( + maybe_fork_ParallelEnv( + 2, [lambda: CountingEnv(max_steps=4), lambda: CountingEnv(max_steps=5)] + ), + TrajCounter(), + ) + env.transform.transform_observation_spec(env.base_env.observation_spec) + r = env.rollout( + 100, + lambda td: td.set("action", torch.ones(env.shape + (1,))), + break_when_any_done=False, + ) + check_env_specs(env) + assert r["traj_count"].max() == 36 + + def test_trans_serial_env_check(self): + env = TransformedEnv( + SerialEnv( + 2, [lambda: CountingEnv(max_steps=4), lambda: CountingEnv(max_steps=5)] + ), + TrajCounter(), + ) + env.transform.transform_observation_spec(env.base_env.observation_spec) + r = env.rollout( + 100, + lambda td: td.set("action", torch.ones(env.shape + (1,))), + break_when_any_done=False, + ) + check_env_specs(env) + assert r["traj_count"].max() == 36 + + def test_transform_env(self): + torch.manual_seed(0) + env = TransformedEnv(CountingEnv(max_steps=4), TrajCounter()) + env.transform.transform_observation_spec(env.base_env.observation_spec) + r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False) + assert r["traj_count"].max() == 19 + + def test_nested(self): + torch.manual_seed(0) + env = TransformedEnv( + CountingEnv(max_steps=4), + Compose( + RenameTransform("done", ("nested", "done"), create_copy=True), + TrajCounter(out_key=(("nested"), (("traj_count",),))), + ), + ) + env.transform.transform_observation_spec(env.base_env.observation_spec) + r = env.rollout(100, lambda td: td.set("action", 1), break_when_any_done=False) + assert r["nested", "traj_count"].max() == 19 + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = TrajCounter() + rb = rbclass(storage=LazyTensorStorage(20)) + rb.append_transform(t) + td = ( + TensorDict( + {("next", "observation"): torch.randn(3), "action": torch.randn(2)}, [] + ) + .expand(10) + .contiguous() + ) + rb.extend(td) + with pytest.raises( + RuntimeError, + match="TrajCounter can only be called within an environment step or reset", + ): + td = rb.sample(10) + + def test_collector_match(self): + # The counter in the collector should match the one from the transform + t = TrajCounter() + + def make_env(max_steps=4): + env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) + env.transform.transform_observation_spec(env.base_env.observation_spec) + return env + + collector = MultiSyncDataCollector( + [EnvCreator(make_env, max_steps=5), EnvCreator(make_env, max_steps=4)], + total_frames=99, + frames_per_batch=8, + ) + for d in collector: + # The env has one more traj because the collector calls reset during init + assert d["collector", "traj_ids"].max() == d["next", "traj_count"].max() - 1 + assert d["traj_count"].max() > 0 + + def test_transform_compose(self): + t = TrajCounter() + t = nn.Sequential(t) + td = ( + TensorDict( + {("next", "observation"): torch.randn(3), "action": torch.randn(2)}, [] + ) + .expand(10) + .contiguous() + ) + + with pytest.raises( + RuntimeError, + match="TrajCounter can only be called within an environment step or reset", + ): + td = t(td) + + def test_transform_inverse(self): + pytest.skip("No inverse transform for TrajCounter") + + def test_transform_model(self): + t = TrajCounter() + td = ( + TensorDict( + {("next", "observation"): torch.randn(3), "action": torch.randn(2)}, [] + ) + .expand(10) + .contiguous() + ) + + with pytest.raises( + RuntimeError, + match="TrajCounter can only be called within an environment step or reset", + ): + td = t(td) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) + def test_transform_no_env(self, device, batch): + pytest.skip("TrajCounter cannot be called without env") + + class TestCatTensors(TransformBase): @pytest.mark.parametrize("append", [True, False]) def test_cattensors_empty(self, append): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 047550fa9d7..56f7a5a3332 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -92,6 +92,7 @@ TensorDictPrimer, TimeMaxPool, ToTensorImage, + TrajCounter, Transform, TransformedEnv, UnsqueezeTransform, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 02c7f5893dc..9e59e0f69d6 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1408,7 +1408,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): # No certainty which module multiprocessing_context is parent_pipe, child_pipe = ctx.Pipe() env_fun = self.create_env_fn[idx] - if not isinstance(env_fun, EnvCreator): + if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)): env_fun = CloudpickleWrapper(env_fun) kwargs[idx].update( { diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index f090289214d..f4cb8e263a1 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections import OrderedDict +from multiprocessing.sharedctypes import Synchronized from typing import Callable, Dict, Optional, Union import torch @@ -33,6 +34,8 @@ class EnvCreator: create_env_kwargs (dict, optional): the kwargs of the env creator. share_memory (bool, optional): if False, the resulting tensordict from the environment won't be placed in shared memory. + **kwargs: additional keyword arguments to be passed to the environment + during construction. Examples: >>> # We create the same environment on 2 processes using VecNorm @@ -79,20 +82,38 @@ def __init__( create_env_fn: Callable[..., EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True, + **kwargs, ) -> None: - if not isinstance(create_env_fn, EnvCreator): + if not isinstance(create_env_fn, (EnvCreator, CloudpickleWrapper)): self.create_env_fn = CloudpickleWrapper(create_env_fn) else: self.create_env_fn = create_env_fn - self.create_env_kwargs = ( - create_env_kwargs if isinstance(create_env_kwargs, dict) else {} - ) + self.create_env_kwargs = kwargs + if isinstance(create_env_kwargs, dict): + self.create_env_kwargs.update(create_env_kwargs) self.initialized = False self._meta_data = None self._share_memory = share_memory self.init_() + def make_variant(self, **kwargs) -> EnvCreator: + """Creates a variant of the EnvCreator, pointing to the same underlying metadata but with different keyword arguments during construction. + + This can be useful with transforms that share a state, like :class:`~torchrl.envs.TrajCounter`. + + Examples: + >>> from torchrl.envs import GymEnv + >>> env_creator_pendulum = EnvCreator(GymEnv, env_name="Pendulum-v1") + >>> env_creator_cartpole = env_creator_pendulum(env_name="CartPole-v1") + + """ + # Copy self + out = type(self).__new__(type(self)) + out.__dict__.update(self.__dict__) + out.create_env_kwargs.update(kwargs) + return out + def share_memory(self, state_dict: OrderedDict) -> None: for key, item in list(state_dict.items()): if isinstance(item, (TensorDictBase,)): @@ -101,7 +122,7 @@ def share_memory(self, state_dict: OrderedDict) -> None: else: torchrl_logger.info( f"{self.env_type}: {item} is already shared" - ) # , deleting key') + ) # , deleting key'val) del state_dict[key] elif isinstance(item, OrderedDict): self.share_memory(item) @@ -120,12 +141,43 @@ def meta_data(self) -> EnvMetaData: def meta_data(self, value: EnvMetaData): self._meta_data = value + @staticmethod + def _is_mp_value(val): + + return isinstance(val, (Synchronized,)) and hasattr(val, "_obj") + + @classmethod + def _find_mp_values(cls, env_or_transform, values, prefix=()): + from torchrl.envs.transforms.transforms import Compose, TransformedEnv + + if isinstance(env_or_transform, EnvBase) and isinstance( + env_or_transform, TransformedEnv + ): + cls._find_mp_values( + env_or_transform.transform, + values=values, + prefix=prefix + ("transform",), + ) + cls._find_mp_values( + env_or_transform.base_env, values=values, prefix=prefix + ("base_env",) + ) + elif isinstance(env_or_transform, Compose): + for i, t in enumerate(env_or_transform.transforms): + cls._find_mp_values(t, values=values, prefix=prefix + (i,)) + for k, v in env_or_transform.__dict__.items(): + if cls._is_mp_value(v): + values.append((prefix + (k,), v)) + return values + def init_(self) -> EnvCreator: shadow_env = self.create_env_fn(**self.create_env_kwargs) tensordict = shadow_env.reset() shadow_env.rand_step(tensordict) self.env_type = type(shadow_env) self._transform_state_dict = shadow_env.state_dict() + # Extract any mp.Value object from the env + self._mp_values = self._find_mp_values(shadow_env, values=[]) + if self._share_memory: self.share_memory(self._transform_state_dict) self.initialized = True @@ -134,11 +186,24 @@ def init_(self) -> EnvCreator: del shadow_env return self + @classmethod + def _set_mp_value(cls, env, key, value): + if len(key) > 1: + if isinstance(key[0], int): + return cls._set_mp_value(env[key[0]], key[1:], value) + else: + return cls._set_mp_value(getattr(env, key[0]), key[1:], value) + else: + setattr(env, key[0], value) + def __call__(self, **kwargs) -> EnvBase: if not self.initialized: raise RuntimeError("EnvCreator must be initialized before being called.") kwargs.update(self.create_env_kwargs) # create_env_kwargs precedes env = self.create_env_fn(**kwargs) + if self._mp_values: + for k, v in self._mp_values: + self._set_mp_value(env, k, v) env.load_state_dict(self._transform_state_dict, strict=False) return env diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 64a25b94e37..bccbd9a4543 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -53,6 +53,7 @@ TensorDictPrimer, TimeMaxPool, ToTensorImage, + TrajCounter, Transform, TransformedEnv, UnsqueezeTransform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b70e05ca431..600e03775f7 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -18,10 +18,12 @@ Callable, Dict, List, + Mapping, Optional, OrderedDict, Sequence, Tuple, + TypeVar, Union, ) @@ -81,6 +83,8 @@ FORWARD_NOT_IMPLEMENTED = "class {} cannot be executed without a parent environment." +T = TypeVar("T", bound="Transform") + def _apply_to_composite(function): @wraps(function) @@ -458,7 +462,7 @@ def reset_parent(self) -> None: self.__dict__["_container"] = None self.__dict__["_parent"] = None - def clone(self): + def clone(self) -> T: self_copy = copy(self) state = copy(self.__dict__) state["_container"] = None @@ -1221,7 +1225,7 @@ def reset_parent(self): t.reset_parent() super().reset_parent() - def clone(self): + def clone(self) -> T: transforms = [] for t in self.transforms: transforms.append(t.clone()) @@ -8678,3 +8682,207 @@ def _inv_call(self, tensordict): if self.sampling == self.SamplingStrategy.RANDOM: action = action + self.jitters * torch.rand_like(self.jitters) return tensordict.set(self.in_keys_inv[0], action) + + +class TrajCounter(Transform): + """Global trajectory counter transform. + + TrajCounter can be used to count the number of trajectories (i.e., the number of times `reset` is called) in any + TorchRL environment. + This transform will work within a single node across multiple processes (see note below). + A single transform can only count the trajectories associated with a single done state, but nested done states are + accepted as long as their prefix matches the prefix of the counter key. + + Args: + out_key (NestedKey, optional): The entry name of the trajectory counter. Defaults to ``"traj_count"``. + + Examples: + >>> from torchrl.envs import GymEnv, StepCounter, TrajCounter + >>> env = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) + >>> env = env.append_transform(TrajCounter()) + >>> r = env.rollout(18, break_when_any_done=False) # 18 // 6 = 3 trajectories + >>> r["next", "traj_count"] + tensor([[0], + [0], + [0], + [0], + [0], + [0], + [1], + [1], + [1], + [1], + [1], + [1], + [2], + [2], + [2], + [2], + [2], + [2]]) + + .. note:: + Sharing a trajectory counter among workers can be done in multiple ways, but it will usually involve wrapping the environment in a :class:`~torchrl.envs.EnvCreator`. Not doing so may result in an error during serialization of the transform. The counter will be shared among the workers, meaning that at any point in time, it is guaranteed that there will not be two environments that will share the same trajectory count (and each (step-count, traj-count) pair will be unique). + Here are examples of valid ways of sharing a ``TrajCounter`` object between processes: + + >>> # Option 1: Create the trajectory counter outside the environment. + >>> # This requires the counter to be cloned within the transformed env, as a single transform object cannot have two parents. + >>> t = TrajCounter() + >>> def make_env(max_steps=4, t=t): + ... # See CountingEnv in torchrl.test.mocking_classes + ... env = TransformedEnv(CountingEnv(max_steps=max_steps), t.clone()) + ... env.transform.transform_observation_spec(env.base_env.observation_spec) + ... return env + >>> penv = ParallelEnv( + ... 2, + ... [EnvCreator(make_env, max_steps=4), EnvCreator(make_env, max_steps=5)], + ... mp_start_method="spawn", + ... ) + >>> # Option 2: Create the transform within the constructor. + >>> # In this scenario, we still need to tell each sub-env what kwarg has to be used. + >>> # Both EnvCreator and ParallelEnv offer that possibility. + >>> def make_env(max_steps=4): + ... t = TrajCounter() + ... env = TransformedEnv(CountingEnv(max_steps=max_steps), t) + ... env.transform.transform_observation_spec(env.base_env.observation_spec) + ... return env + >>> make_env_c0 = EnvCreator(make_env) + >>> # Create a variant of the env with different kwargs + >>> make_env_c1 = make_env_c0.make_variant(max_steps=5) + >>> penv = ParallelEnv( + ... 2, + ... [make_env_c0, make_env_c1], + ... mp_start_method="spawn", + ... ) + >>> # Alternatively, pass the kwargs to the ParallelEnv + >>> penv = ParallelEnv( + ... 2, + ... [make_env_c0, make_env_c0], + ... create_env_kwargs=[{"max_steps": 5}, {"max_steps": 4}], + ... mp_start_method="spawn", + ... ) + + """ + + def __init__(self, out_key: NestedKey = "traj_count"): + super().__init__(in_keys=[], out_keys=[out_key]) + self._make_shared_value() + self._initialized = False + + def _make_shared_value(self): + self._traj_count = mp.Value("i", 0) + + def __getstate__(self): + state = super().__getstate__() + state["_traj_count"] = None + return state + + def clone(self): + clone = super().clone() + # All clones share the same _traj_count and lock + clone._traj_count = self._traj_count + return clone + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + if not self._initialized: + self._initialized = True + rk = self.parent.reset_keys + traj_count_key = self.out_keys[0] + is_str = isinstance(traj_count_key, str) + for _rk in rk: + if is_str and isinstance(_rk, str): + rk = _rk + break + elif ( + not is_str + and isinstance(_rk, tuple) + and _rk[:-1] == traj_count_key[:-1] + ): + rk = _rk + break + else: + raise RuntimeError( + f"Did not find reset key that matched the prefix of the traj counter key. Reset keys: {rk}, traj count: {traj_count_key}" + ) + reset = None + if tensordict is not None: + reset = tensordict.get(rk, default=None) + if reset is None: + reset = torch.ones( + self.container.observation_spec[self.out_keys[0]].shape, + device=tensordict_reset.device, + dtype=torch.bool, + ) + with (self._traj_count): + tc = int(self._traj_count.value) + self._traj_count.value = self._traj_count.value + reset.sum().item() + episodes = torch.arange(tc, tc + reset.sum(), device=self.parent.device) + episodes = torch.masked_scatter( + torch.zeros_like(reset, dtype=episodes.dtype), reset, episodes + ) + tensordict_reset.set(traj_count_key, episodes) + return tensordict_reset + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + if not self._initialized: + raise RuntimeError("_step was called before _reset was called.") + next_tensordict.set(self.out_keys[0], tensordict.get(self.out_keys[0])) + return next_tensordict + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + raise RuntimeError( + f"{type(self).__name__} can only be called within an environment step or reset." + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + raise RuntimeError( + f"{type(self).__name__} can only be called within an environment step or reset." + ) + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + return { + "traj_count": int(self._traj_count.value), + } + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + self._traj_count.value *= 0 + self._traj_count.value += state_dict["traj_count"] + + def transform_observation_spec(self, observation_spec: Composite) -> Composite: + if not isinstance(observation_spec, Composite): + raise ValueError( + f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead." + ) + full_done_spec = self.parent.output_spec["full_done_spec"] + traj_count_key = self.out_keys[0] + # find a matching done key (there might be more than one) + for done_key in self.parent.done_keys: + # check root + if type(done_key) != type(traj_count_key): + continue + if isinstance(done_key, tuple): + if done_key[:-1] == traj_count_key[:-1]: + shape = full_done_spec[done_key].shape + break + if isinstance(done_key, str): + shape = full_done_spec[done_key].shape + break + + else: + raise KeyError( + f"Could not find root of traj_count key {traj_count_key} in done keys {self.done_keys}." + ) + observation_spec[traj_count_key] = Bounded( + shape=shape, + dtype=torch.int64, + device=observation_spec.device, + low=0, + high=torch.iinfo(torch.int64).max, + ) + return super().transform_observation_spec(observation_spec) From fa64c2f895b62c02a18e9057e5f4ba6dfbb2e7e8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 4 Nov 2024 13:11:57 +0000 Subject: [PATCH 9/9] [Versioning] v0.6.1 ghstack-source-id: 675a40340ec8e0adaa02aa15fb0d1367931bdce1 Pull Request resolved: https://github.com/pytorch/rl/pull/2533 --- .github/scripts/td_script.sh | 2 +- .github/scripts/version_script.bat | 2 +- .github/workflows/wheels-legacy.yml | 2 +- setup.py | 2 +- version.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/scripts/td_script.sh b/.github/scripts/td_script.sh index 68c8939b9c1..1b6fbe49bd6 100644 --- a/.github/scripts/td_script.sh +++ b/.github/scripts/td_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export TORCHRL_BUILD_VERSION=0.6.0 +export TORCHRL_BUILD_VERSION=0.6.1 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/scripts/version_script.bat b/.github/scripts/version_script.bat index cb4187207a2..346a5b3182d 100644 --- a/.github/scripts/version_script.bat +++ b/.github/scripts/version_script.bat @@ -1,3 +1,3 @@ @echo off -set TORCHRL_BUILD_VERSION=0.6.0 +set TORCHRL_BUILD_VERSION=0.6.1 echo TORCHRL_BUILD_VERSION is set to %TORCHRL_BUILD_VERSION% diff --git a/.github/workflows/wheels-legacy.yml b/.github/workflows/wheels-legacy.yml index 2de72efa0f8..d0daabe8449 100644 --- a/.github/workflows/wheels-legacy.yml +++ b/.github/workflows/wheels-legacy.yml @@ -35,7 +35,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - TORCHRL_BUILD_VERSION=0.6.0 python3 setup.py bdist_wheel + TORCHRL_BUILD_VERSION=0.6.1 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v3 with: diff --git a/setup.py b/setup.py index f3456f4e7cf..8633805da54 100644 --- a/setup.py +++ b/setup.py @@ -176,7 +176,7 @@ def _main(argv): if is_nightly: tensordict_dep = "tensordict-nightly" else: - tensordict_dep = "tensordict>=0.6.0" + tensordict_dep = "tensordict>=0.6.1" if is_nightly: version = get_nightly_version() diff --git a/version.txt b/version.txt index a918a2aa18d..ee6cdce3c29 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.6.0 +0.6.1