diff --git a/benchmarks/test_replaybuffer_benchmark.py b/benchmarks/test_replaybuffer_benchmark.py
index 34116ff9703..6336e7d3461 100644
--- a/benchmarks/test_replaybuffer_benchmark.py
+++ b/benchmarks/test_replaybuffer_benchmark.py
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
     )
 
 
-class create_tensor_rb:
-    def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
+class create_compiled_tensor_rb:
+    def __init__(
+        self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
+    ):
         self.storage = storage
         self.rb = rb
         self.sampler = sampler
-        self.size = size
+        self.storage_size = storage_size
+        self.data_size = data_size
         self.iters = iters
+        self.compilable = compilable
 
     def __call__(self):
         kwargs = {}
         if self.sampler is not None:
             kwargs["sampler"] = self.sampler()
         if self.storage is not None:
-            kwargs["storage"] = self.storage(10 * self.size)
+            kwargs["storage"] = self.storage(
+                self.storage_size, compilable=self.compilable
+            )
 
-        rb = self.rb(batch_size=3, **kwargs)
-        data = torch.randn(self.size, 1)
+        rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
+        data = torch.randn(self.data_size, 1)
         return ((rb, data, self.iters), {})
 
@@ -210,21 +216,32 @@ def fn(td):
 
 @pytest.mark.parametrize(
-    "rb,storage,sampler,size,iters,compiled",
+    "rb,storage,sampler,storage_size,data_size,iters,compiled",
     [
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
     ],
 )
-def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
+def test_rb_extend_sample(
+    benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
+):
+    if compiled:
+        torch._dynamo.reset_code_caches()
+
     benchmark.pedantic(
         extend_and_sample_compiled if compiled else extend_and_sample,
-        setup=create_tensor_rb(
+        setup=create_compiled_tensor_rb(
             rb=rb,
             storage=storage,
             sampler=sampler,
-            size=size,
+            storage_size=storage_size,
+            data_size=data_size,
             iters=iters,
+            compilable=compiled,
         ),
         iterations=1,
         warmup_rounds=10,
diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 9b072cc9664..ce6c9e487b9 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -985,11 +985,13 @@ TorchRL offers a set of classes and functions that can be used to represent tree
 
     BinaryToDecimal
     HashToInt
+    MCTSForest
     QueryModule
     RandomProjectionHash
     SipHash
     TensorDictMap
     TensorMap
+    Tree
 
 
 Reinforcement Learning From Human Feedback (RLHF)
diff --git a/setup.py b/setup.py
index 8633805da54..75a2486815e 100644
--- a/setup.py
+++ b/setup.py
@@ -275,7 +275,6 @@ def _main(argv):
         extras_require=extra_requires,
         zip_safe=False,
         classifiers=[
-            "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
             "Programming Language :: Python :: 3.11",
diff --git a/test/test_cost.py b/test/test_cost.py
index 0066c024776..0b36f5b8961 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -7650,6 +7650,7 @@ def _create_mock_actor(
         observation_key="observation",
         sample_log_prob_key="sample_log_prob",
         composite_action_dist=False,
+        aggregate_probabilities=True,
     ):
         # Actor
         action_spec = Bounded(
@@ -7668,7 +7669,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
-            aggregate_probabilities=True,
+            aggregate_probabilities=aggregate_probabilities,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8038,6 +8039,96 @@ def test_ppo(
         assert counter == 2
         actor.zero_grad()
 
+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    @pytest.mark.parametrize("gradient_mode", (True, False))
+    @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+    @pytest.mark.parametrize("functional", [True, False])
+    def test_ppo_composite_no_aggregate(
+        self, loss_class, device, gradient_mode, advantage, td_est, functional
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_seq_mock_data_ppo(device=device, composite_action_dist=True)
+
+        actor = self._create_mock_actor(
+            device=device,
+            composite_action_dist=True,
+            aggregate_probabilities=False,
+        )
+        value = self._create_mock_value(device=device)
+        if advantage == "gae":
+            advantage = GAE(
+                gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage == "vtrace":
+            advantage = VTrace(
+                gamma=0.9,
+                value_network=value,
+                actor_network=actor,
+                differentiable=gradient_mode,
+            )
+        elif advantage == "td":
+            advantage = TD1Estimator(
+                gamma=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage == "td_lambda":
+            advantage = TDLambdaEstimator(
+                gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage is None:
+            pass
+        else:
+            raise NotImplementedError
+
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            functional=functional,
+        )
+        if advantage is not None:
+            advantage(td)
+        else:
+            if td_est is not None:
+                loss_fn.make_value_estimator(td_est)
+
+        loss = loss_fn(td)
+        if isinstance(loss_fn, KLPENPPOLoss):
+            kl = loss.pop("kl_approx")
+            assert (kl != 0).any()
+
+        loss_critic = loss["loss_critic"]
+        loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
+        loss_critic.backward(retain_graph=True)
+        # check that grads are independent and non null
+        named_parameters = loss_fn.named_parameters()
+        counter = 0
+        for name, p in named_parameters:
+            if p.grad is not None and p.grad.norm() > 0.0:
+                counter += 1
+                assert "actor" not in name
+                assert "critic" in name
+            if p.grad is None:
+                assert ("actor" in name) or ("target_" in name)
+                assert ("critic" not in name) or ("target_" in name)
+        assert counter == 2
+
+        value.zero_grad()
+        loss_objective.backward()
+        counter = 0
+        named_parameters = loss_fn.named_parameters()
+        for name, p in named_parameters:
+            if p.grad is not None and p.grad.norm() > 0.0:
+                counter += 1
+                assert "actor" in name
+                assert "critic" not in name
+            if p.grad is None:
+                assert ("actor" not in name) or ("target_" in name)
+                assert ("critic" in name) or ("target_" in name)
+        assert counter == 2
+        actor.zero_grad()
+
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("gradient_mode", (True,))
     @pytest.mark.parametrize("device", get_default_devices())
diff --git a/test/test_rb.py b/test/test_rb.py
index c14ccb64c04..6ceefef29d1 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -178,18 +178,24 @@
 )
 @pytest.mark.parametrize("size", [3, 5, 100])
 class TestComposableBuffers:
-    def _get_rb(self, rb_type, size, sampler, writer, storage):
+    def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):
         if storage is not None:
-            storage = storage(size)
+            storage = storage(size, compilable=compilable)
         sampler_args = {}
         if sampler is samplers.PrioritizedSampler:
             sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}
         sampler = sampler(**sampler_args)
-        writer = writer()
-        rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
+        writer = writer(compilable=compilable)
+        rb = rb_type(
+            storage=storage,
+            sampler=sampler,
+            writer=writer,
+            batch_size=3,
+            compilable=compilable,
+        )
         return rb
 
     def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
     # # Our Windows CI jobs do not have "cl", so skip this test.
     @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    @pytest.mark.parametrize("avoid_max_size", [False, True])
     def test_extend_sample_recompile(
-        self, rb_type, sampler, writer, storage, size, datatype
+        self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
     ):
         if rb_type is not ReplayBuffer:
             pytest.skip(
@@ -443,15 +450,26 @@ def test_extend_sample_recompile(
         torch._dynamo.reset_code_caches()
 
-        storage_size = 10 * size
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = size
+
+        # These two cases are separated because when the max storage size is
+        # reached, the code execution path changes, causing necessary
+        # recompiles.
+        if avoid_max_size:
+            storage_size = (num_extend + 1) * data_size
+        else:
+            storage_size = 2 * data_size
+
         rb = self._get_rb(
             rb_type=rb_type,
             sampler=sampler,
             writer=writer,
             storage=storage,
             size=storage_size,
+            compilable=True,
         )
-        data_size = size
         data = self._get_data(datatype, size=data_size)
 
         @torch.compile
         def extend_and_sample(data):
             rb.extend(data)
             return rb.sample()
 
-        # Number of times to extend the replay buffer
-        num_extend = 30
-
-        # NOTE: The first two calls to 'extend' and 'sample' currently cause
-        # recompilations, so avoid capturing those for now.
-        num_extend_before_capture = 2
+        # NOTE: The first three calls to 'extend' and 'sample' can currently
+        # cause recompilations, so avoid capturing those.
+        num_extend_before_capture = 3
 
         for _ in range(num_extend_before_capture):
             extend_and_sample(data)
@@ -477,12 +492,12 @@ def test_extend_sample_recompile(
             for _ in range(num_extend - num_extend_before_capture):
                 extend_and_sample(data)
 
-            assert len(rb) == storage_size
-            assert len(records) == 0
-
         finally:
             torch._logging.set_logs()
 
+        assert len(rb) == min((num_extend * data_size), storage_size)
+        assert len(records) == 0
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
@@ -806,6 +821,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
         s = new_replay_buffer.sample()
         assert (s.exclude("index") == 1).all()
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+    )
+    @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    # This test checks if the `torch._dynamo.disable` wrapper around
+    # `TensorStorage._rand_given_ndim` is still necessary.
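+    # A hedged sketch of why that wrapper exists (illustration only, not part
+    # of the test): index sampling is morally equivalent to
+    #
+    #     torch.randint(0, len(storage), (batch_size,))
+    #
+    # and dynamo guards on `len(storage)`, so a storage that grows on every
+    # extend can trigger a recompile per call unless the sampling step is
+    # hidden from the compiler.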
+    def test__rand_given_ndim_recompile(self):
+        torch._dynamo.reset_code_caches()
+
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = 100
+        storage_size = (num_extend + 1) * data_size
+        sample_size = 3
+
+        storage = LazyTensorStorage(storage_size, compilable=True)
+        sampler = RandomSampler()
+
+        # Override to avoid the `torch._dynamo.disable` wrapper
+        storage._rand_given_ndim = storage._rand_given_ndim_impl
+
+        @torch.compile
+        def extend_and_sample(data):
+            storage.set(torch.arange(data_size) + len(storage), data)
+            return sampler.sample(storage, sample_size)
+
+        data = torch.randint(100, (data_size, 1))
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend):
+                extend_and_sample(data)
+
+        finally:
+            torch._logging.set_logs()
+
+        assert len(storage) == num_extend * data_size
+        assert len(records) == 8, (
+            "If this ever decreases, that's probably good news and the "
+            "`torch._dynamo.disable` wrapper around "
+            "`TensorStorage._rand_given_ndim` can be removed."
+        )
+
     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
     def test_extend_lazystack(self, storage_type):
diff --git a/test/test_storage_map.py b/test/test_storage_map.py
index 7473a241140..b2b1a3ed8cb 100644
--- a/test/test_storage_map.py
+++ b/test/test_storage_map.py
@@ -5,12 +5,13 @@
 import argparse
 import functools
 import importlib.util
+from typing import Tuple
 
 import pytest
 
 import torch
 
-from tensordict import TensorDict
+from tensordict import assert_close, TensorDict
 from torchrl.data import LazyTensorStorage, ListStorage, MCTSForest
 from torchrl.data.map import (
     BinaryToDecimal,
@@ -19,7 +20,7 @@
     SipHash,
     TensorDictMap,
 )
-from torchrl.envs import GymEnv, PendulumEnv, UnsqueezeTransform, CatTensors, StepCounter
+from torchrl.envs import GymEnv
 
 _has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
     "gym", None
@@ -237,65 +238,236 @@ def test_map_rollout(self):
         assert contains[: rollout.shape[-1]].all()
         assert not contains[rollout.shape[-1] :].any()
 
+
 class TestMCTSForest:
-    def test_forest_build(self):
-        forest = MCTSForest()
-        env = PendulumEnv()
-        obs_keys = list(env.observation_spec.keys(True, True))
-        state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
-        # Appending transforms to get an "observation" key that concatenates the observations together
-        env = env.append_transform(
-            UnsqueezeTransform(
-                in_keys=obs_keys,
-                out_keys=[("unsqueeze", key) for key in obs_keys],
-                dim=-1
-            )
-        )
-        env = env.append_transform(
-            CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
+    def dummy_rollouts(self) -> Tuple[TensorDict, ...]:
+        """
+        ├── 0
+        │   ├── 16
+        │   ├── 17
+        │   ├── 18
+        │   ├── 19
+        │   └── 20
+        ├── 1
+        ├── 2
+        ├── 3
+        │   ├── 6
+        │   ├── 7
+        │   ├── 8
+        │   ├── 9
+        │   └── 10
+        ├── 4
+        │   ├── 11
+        │   ├── 12
+        │   ├── 13
+        │   │   ├── 21
+        │   │   ├── 22
+        │   │   ├── 23
+        │   │   ├── 24
+        │   │   └── 25
+        │   ├── 14
+        │   └── 15
+        └── 5
+
+        """
+
+        states0 = torch.arange(6)
+        actions0 = torch.full((5,), 0)
+
+        states1 = torch.cat([torch.tensor([3]), torch.arange(6, 11)])
+        actions1 = torch.full((5,), 1)
+
+        states2 = torch.cat([torch.tensor([4]), torch.arange(11, 16)])
+        actions2 = torch.full((5,), 2)
+
+        states3 = torch.cat([torch.tensor([0]), torch.arange(16, 21)])
+        actions3 = torch.full((5,), 3)
+
+        states4 = torch.cat([torch.tensor([13]), torch.arange(21, 26)])
+        actions4 = torch.full((5,), 4)
+
+        return (
+            self._make_td(states0, actions0),
+            self._make_td(states1, actions1),
+            self._make_td(states2, actions2),
+            self._make_td(states3, actions3),
+            self._make_td(states4, actions4),
         )
-        env = env.append_transform(StepCounter())
-        env.set_seed(0)
-        # Get a reset state, then make a rollout out of it
-        reset_state = env.reset()
-        rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
-        # Append the rollout to the forest. We're removing the state entries for clarity
-        rollout0 = rollout0.copy()
-        rollout0.exclude(*state_keys, inplace=True)
-        rollout0.get("next").exclude(*state_keys, inplace=True)
-        forest.extend(rollout0)
-        # The forest should have 6 elements (the length of the rollout)
-        assert len(forest) == 6
-        # Let's make another rollout from the same reset state
-        rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
-        rollout1.exclude(*state_keys, inplace=True)
-        rollout1.get("next").exclude(*state_keys, inplace=True)
-        forest.extend(rollout1)
-        assert len(forest) == 12
-        r = rollout0[0]
-        tree = forest.get_tree(r)
-        tree.shape
-        # (torch.Size([]), _StringKeys(dict_keys(['data_content', 'children', 'count'])))
-        tree.count
-        # tensor(2, dtype=torch.int32)
-        tree.children.shape
-        # (torch.Size([2]),
-        #  _StringKeys(dict_keys(['node', 'action', 'reward', 'index', 'hash'])))
-        starttd = rollout1[2]
-        rollout2 = env.rollout(6, auto_reset=False, tensordict=starttd)
-        rollout2.exclude(*state_keys, inplace=True)
-        rollout2.get("next").exclude(*state_keys, inplace=True)
-        forest.extend(rollout2)
-        assert len(forest) == 18
-        r = rollout0[0]
-        tree = forest.get_tree(r)
-        assert (tree.children.node.children.node.count == torch.tensor([[1], [2]])).all()
-        assert tree.children.node.children.node.shape == torch.Size((2, 1))
-        tree.children.node.children.node.children.shape, tree.children.node.children.node.children[0].shape, tree.children.node.children.node.children[1].shape
-
-    def test_forest_extend_and_get(self):
-        ...
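+    # NOTE: `dummy_rollouts` above hand-builds TED-format rollouts: for
+    # states [s_0, ..., s_n], "observation" holds states[:-1] and
+    # ("next", "observation") holds states[1:], so rollouts can share a
+    # prefix with one another (see `_make_td` below).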
+
+    def _state0(self) -> TensorDict:
+        return self.dummy_rollouts()[0][0]
+
+    @staticmethod
+    def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
+        done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1)
+        reward = action.clone()
+
+        return TensorDict(
+            {
+                "observation": state[:-1],
+                "action": action,
+                "done": torch.zeros_like(done),
+                "next": {
+                    "observation": state[1:],
+                    "done": done,
+                    "reward": reward,
+                },
+            }
+        ).auto_batch_size_()
+
+    def _make_forest(self) -> MCTSForest:
+        r0, r1, r2, r3, r4 = self.dummy_rollouts()
+        assert r0.shape
+        forest = MCTSForest(consolidated=True)
+        forest.extend(r0)
+        forest.extend(r1)
+        forest.extend(r2)
+        forest.extend(r3)
+        forest.extend(r4)
+        return forest
+
+    def _make_forest_intersect(self) -> MCTSForest:
+        """
+        ├── 0
+        │   ├── 16
+        │   ├── 17
+        │   ├── 18
+        │   ├── 19───────│
+        │   │   └── 26   │
+        │   └── 20       │
+        ├── 1            │
+        ├── 2            │
+        ├── 3            │
+        │   ├── 6        │
+        │   ├── 7        │
+        │   ├── 8        │
+        │   ├── 9        │
+        │   └── 10       │
+        ├── 4            │
+        │   ├── 11       │
+        │   ├── 12       │
+        │   ├── 13       │
+        │   │   ├── 21   │
+        │   │   ├── 22   │
+        │   │   ├── 23   │
+        │   │   ├── 24 ──│
+        │   │   └── 25
+        │   ├── 14
+        │   └── 15
+        └── 5
+        """
+        forest = self._make_forest()
+        states5 = torch.cat([torch.tensor([24]), torch.tensor([19, 26])])
+        actions5 = torch.full((2,), 5)
+        rollout5 = self._make_td(states5, actions5)
+        forest.extend(rollout5)
+        return forest
+
+    def test_forest_build(self):
+        r0, *_ = self.dummy_rollouts()
+        forest = self._make_forest()
+        tree = forest.get_tree(r0[0])
+
+    def test_forest_vertices(self):
+        r0, *_ = self.dummy_rollouts()
+        forest = self._make_forest()
+
+        tree = forest.get_tree(r0[0])
+        assert tree.num_vertices() == 9  # (0, 20, 3, 10, 4, 13, 25, 15, 5)
+
+        tree = forest.get_tree(r0[0], compact=False)
+        assert tree.num_vertices() == 26
+
+    def test_forest_rebuild_rollout(self):
+        r0, r1, r2, r3, r4 = self.dummy_rollouts()
+        forest = self._make_forest()
+
+        tree = forest.get_tree(r0[0])
+        assert_close(tree.rollout_from_path((0, 0, 0)), r0, intersection=True)
+        assert_close(tree.rollout_from_path((0, 1))[-5:], r1, intersection=True)
+        assert_close(tree.rollout_from_path((0, 0, 1, 0))[-5:], r2, intersection=True)
+        assert_close(tree.rollout_from_path((1,))[-5:], r3, intersection=True)
+        assert_close(tree.rollout_from_path((0, 0, 1, 1))[-5:], r4, intersection=True)
+
+    def test_forest_check_hashes(self):
+        r0, *_ = self.dummy_rollouts()
+        forest = self._make_forest()
+        tree = forest.get_tree(r0[0])
+        nodes = range(tree.num_vertices())
+        hashes = set()
+        for n in nodes:
+            vertex = tree.get_vertex_by_id(n)
+            node_hash = vertex.hash
+            if node_hash is not None:
+                assert isinstance(node_hash, int)
+                hashes.add(node_hash)
+            else:
+                assert vertex is tree
+        assert len(hashes) == tree.num_vertices() - 1
+
+    def test_forest_check_ids(self):
+        r0, *_ = self.dummy_rollouts()
+        forest = self._make_forest()
+        tree = forest.get_tree(r0[0])
+        nodes = range(tree.num_vertices())
+        for n in nodes:
+            vertex = tree.get_vertex_by_id(n)
+            node_id = vertex.node_id
+            assert isinstance(node_id, int)
+            assert node_id == n
+
+    # Ideally, we'd like to have only views but because we index the storage with a tensor
+    # we actually get regular, single-storage tensors
+    # def test_forest_view(self):
+    #     import tensordict.base
+    #     r0, *_ = self.dummy_rollouts()
+    #     forest = self._make_forest()
+    #     tree = forest.get_tree(r0[0])
+    #     dataptr = set()
+    #     # Check that all tensors point to the same storage (ie, that we only have views)
+    #     for k, v in tree.items(True, True, is_leaf=tensordict.base._NESTED_TENSORS_AS_LISTS):
+    #         if isinstance(k, tuple) and "rollout" in k:
+    #             dataptr.add(v.storage().data_ptr())
+    #             assert len(dataptr) == 1, k
+
+    def test_forest_intersect(self):
+        state0 = self._state0()
+        forest = self._make_forest_intersect()
+        tree = forest.get_tree(state0)
+        subtree = forest.get_tree(TensorDict(observation=19))
+
+        def make_labels(tree):
+            if tree.rollout is not None:
+                s = torch.cat(
+                    [
+                        tree.rollout["observation"][:1],
+                        tree.rollout["next", "observation"],
+                    ]
+                )
+                s = s.tolist()
+                return f"{tree.node_id}: {s}"
+            return f"{tree.node_id}"
+
+        # subtree.plot(make_labels=make_labels)
+        # tree.plot(make_labels=make_labels)
+        assert tree.get_vertex_by_id(2).num_children == 2
+        assert tree.get_vertex_by_id(1).num_children == 2
+        assert tree.get_vertex_by_id(3).num_children == 2
+        assert tree.get_vertex_by_id(8).num_children == 2
+        assert tree.get_vertex_by_id(10).num_children == 2
+        assert tree.get_vertex_by_id(12).num_children == 2
+
+        # Test contains
+        assert subtree in tree
+
+    def test_forest_intersect_vertices(self):
+        state0 = self._state0()
+        forest = self._make_forest_intersect()
+        tree = forest.get_tree(state0)
+        assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash"))
+        assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash"))
+        with pytest.raises(ValueError, match="key_type must be"):
+            tree.vertices(key_type="another key type")
 
 
 if __name__ == "__main__":
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 59612d2bd65..a5a4fad4e40 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -159,6 +159,7 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
+from torchrl.modules.utils import get_primers_from_module
 
 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -7163,6 +7164,33 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all
 
+    def test_spec_shape_inplace_correction(self):
+        hidden_size = input_size = num_layers = 2
+        model = GRUModule(
+            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
+        )
+        env = TransformedEnv(
+            SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
+        )
+        # These primers do not have the leading batch dimension
+        # since model is agnostic to batch dimension that will be used.
+        primers = get_primers_from_module(model)
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [num_layers, hidden_size]
+            )
+        env.append_transform(primers)
+
+        # Reset should add the batch dimension to the primers
+        # since the parent exists and is batch_locked.
+        td = env.reset()
+
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [2, num_layers, hidden_size]
+            )
+            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
+
 
 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])
diff --git a/torchrl/_extension.py b/torchrl/_extension.py
index a9e52dbf9a4..61eedb46418 100644
--- a/torchrl/_extension.py
+++ b/torchrl/_extension.py
@@ -6,6 +6,13 @@
 import importlib.util
 import warnings
 
+from packaging.version import parse
+
+try:
+    from .version import __version__
+except ImportError:
+    __version__ = None
+
 
 def is_module_available(*modules: str) -> bool:
     """Returns if a top-level module with :attr:`name` exists *without* importing it.
@@ -24,8 +31,28 @@ def _init_extension():
         return
 
 
-EXTENSION_WARNING = (
-    "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
-    "If you installed TorchRL from PyPI, please report the bug on TorchRL github. "
-    "If you installed TorchRL locally and/or in development mode, check that you have all the required compiling packages."
-)
+def _is_nightly(version):
+    if version is None:
+        return True
+    parsed_version = parse(version)
+    return parsed_version.local is not None
+
+
+if _is_nightly(__version__):
+    EXTENSION_WARNING = (
+        "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
+        "You seem to be using the nightly version of TorchRL. If this is a local install, there might be an issue with "
+        "the local installation. Here are some tips to debug this:\n"
+        "  - make sure ninja and cmake were installed\n"
+        "  - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised\n"
+        "  - make sure the version of PyTorch you are using matches the one that was present in your virtual env during "
+        "setup."
+    )
+
+else:
+    EXTENSION_WARNING = (
+        "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
+        "This is likely due to a discrepancy between your package version and the PyTorch version. Make sure both are compatible. "
+        "Usually, torchrl majors follow the pytorch majors within a few days around the release. "
+        "For instance, TorchRL 0.5 requires PyTorch 2.4.0, and TorchRL 0.6 requires PyTorch 2.5.0."
+    )
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 3af44ee0ed7..31e00614fd9 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -252,6 +252,11 @@ class implement_for:
     Keyword Args:
         class_method (bool, optional): if ``True``, the function will be written as a class method.
             Defaults to ``False``.
+        compilable (bool, optional): If ``False``, the module import happens
+            only on the first call to the wrapped function. If ``True``, the
+            module import happens when the wrapped function is initialized. This
+            allows the wrapped function to work well with ``torch.compile``.
+            Defaults to ``False``.
 
     Examples:
         >>> @implement_for("gym", "0.13", "0.14")
@@ -290,11 +295,13 @@ def __init__(
         to_version: str = None,
         *,
         class_method: bool = False,
+        compilable: bool = False,
     ):
         self.module_name = module_name
         self.from_version = from_version
         self.to_version = to_version
         self.class_method = class_method
+        self._compilable = compilable
         implement_for._setters.append(self)
 
     @staticmethod
@@ -386,18 +393,27 @@ def __call__(self, fn):
         self.fn = fn
         implement_for._lazy_impl[self.func_name].append(self._call)
 
-        @wraps(fn)
-        def _lazy_call_fn(*args, **kwargs):
-            # first time we call the function, we also do the replacement.
-            # This will cause the imports to occur only during the first call to fn
+        if self._compilable:
+            _call_fn = self._delazify(self.func_name)
 
-            result = self._delazify(self.func_name)(*args, **kwargs)
-            return result
+            if self.class_method:
+                return classmethod(_call_fn)
 
-        if self.class_method:
-            return classmethod(_lazy_call_fn)
+            return _call_fn
+        else:
+
+            @wraps(fn)
+            def _lazy_call_fn(*args, **kwargs):
+                # first time we call the function, we also do the replacement.
+                # This will cause the imports to occur only during the first call to fn
+
+                result = self._delazify(self.func_name)(*args, **kwargs)
+                return result
+
+            if self.class_method:
+                return classmethod(_lazy_call_fn)
 
-        return _lazy_call_fn
+            return _lazy_call_fn
 
     def _call(self):
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index 88953ced3b2..639ed820e86 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -6,14 +6,13 @@
 from .map import (
     BinaryToDecimal,
     HashToInt,
-    MCTSChildren,
     MCTSForest,
-    MCTSNode,
     QueryModule,
     RandomProjectionHash,
     SipHash,
     TensorDictMap,
     TensorMap,
+    Tree,
 )
 from .postprocs import MultiStep
 from .replay_buffers import (
diff --git a/torchrl/data/map/__init__.py b/torchrl/data/map/__init__.py
index 7ef1f61a845..c9bc25477c2 100644
--- a/torchrl/data/map/__init__.py
+++ b/torchrl/data/map/__init__.py
@@ -6,4 +6,4 @@
 from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
 from .query import HashToInt, QueryModule
 from .tdstorage import TensorDictMap, TensorMap
-from .tree import MCTSChildren, MCTSForest, MCTSNode
+from .tree import MCTSForest, Tree
diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py
index 963f3b225fc..01988dc43be 100644
--- a/torchrl/data/map/hash.py
+++ b/torchrl/data/map/hash.py
@@ -9,6 +9,7 @@
 import torch
 from torch.nn import Module
 
+
 class BinaryToDecimal(Module):
     """A Module to convert binaries encoded tensors to decimals.
diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py
index a26d80cc334..ff0fb4dfe24 100644
--- a/torchrl/data/map/query.py
+++ b/torchrl/data/map/query.py
@@ -169,10 +169,10 @@ def __init__(
     def forward(
         self,
         tensordict: TensorDictBase,
-        *, 
+        *,
         extend: bool = True,
         write_hash: bool = True,
-        clone: bool| None = None,
+        clone: bool | None = None,
     ) -> TensorDictBase:
         hash_values = []
diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py
index 8cf7978cb86..a601f1e3261 100644
--- a/torchrl/data/map/tdstorage.py
+++ b/torchrl/data/map/tdstorage.py
@@ -171,10 +171,12 @@ def from_tensordict_pair(
         dest,
         in_keys: List[NestedKey],
         out_keys: List[NestedKey] | None = None,
+        max_size: int = 1000,
         storage_constructor: type | None = None,
         hash_module: Callable | None = None,
         collate_fn: Callable[[Any], Any] | None = None,
         write_fn: Callable[[Any, Any], Any] | None = None,
+        consolidated: bool | None = None,
     ):
         """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
 
@@ -185,6 +187,8 @@ def from_tensordict_pair(
             out_keys (List[NestedKey]): a list of keys to return in the output tensordict.
                 All keys absent from out_keys, even if present in ``dest``, will not be stored
                 in the storage. Defaults to ``None`` (all keys are registered).
+            max_size (int, optional): the maximum number of elements in the storage. Ignored if the
+                ``storage_constructor`` is passed. Defaults to ``1000``.
             storage_constructor (type, optional): a type of tensor storage.
                 Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`.
                 Other options include :class:`~tensordict.nn.storage.FixedStorage`.
@@ -195,6 +199,8 @@ def from_tensordict_pair(
                 storage. Defaults to a custom value for each known storage type (stack for
                 :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage`
                 subtypes and others).
+            consolidated (bool, optional): whether to consolidate the storage in a single storage tensor.
+                Defaults to ``False``.
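+
+                A minimal usage sketch of the two new arguments (hypothetical
+                tensors; the keys are illustrative only):
+
+                    >>> import torch
+                    >>> from tensordict import TensorDict
+                    >>> source = TensorDict(observation=torch.randn(4, 3), batch_size=[4])
+                    >>> dest = TensorDict(index=torch.arange(4), batch_size=[4])
+                    >>> tdmap = TensorDictMap.from_tensordict_pair(
+                    ...     source, dest, in_keys=["observation"],
+                    ...     max_size=100, consolidated=True,
+                    ... )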
 
         Examples:
             >>> # The following example requires torchrl and gymnasium to be installed
@@ -242,7 +248,13 @@ def from_tensordict_pair(
 
         # Build key_to_storage
         if storage_constructor is None:
-            storage_constructor = functools.partial(LazyTensorStorage, 1000)
+            storage_constructor = functools.partial(
+                LazyTensorStorage, max_size, consolidated=bool(consolidated)
+            )
+        elif consolidated is not None:
+            storage_constructor = functools.partial(
+                storage_constructor, consolidated=consolidated
+            )
         storage = storage_constructor()
         result = cls(
             query_module=query_module,
@@ -257,7 +269,9 @@ def clear(self) -> None:
         for mem in self.storage.values():
             mem.clear()
 
-    def _to_index(self, item: TensorDictBase, extend: bool, clone: bool | None=None) -> torch.Tensor:
+    def _to_index(
+        self, item: TensorDictBase, extend: bool, clone: bool | None = None
+    ) -> torch.Tensor:
         item = self.query_module(item, extend=extend, clone=clone)
         return item[self.index_key]
diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py
index b4a9fd8048c..efbf6f79553 100644
--- a/torchrl/data/map/tree.py
+++ b/torchrl/data/map/tree.py
@@ -4,54 +4,288 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
-import weakref
-from typing import List
+from collections import deque
+
+from typing import Any, Callable, Dict, List, Literal, Tuple
 
 import torch
-from tensordict import LazyStackedTensorDict, NestedKey, tensorclass, TensorDict
-from torchrl.data.replay_buffers.storages import ListStorage
+from tensordict import (
+    merge_tensordicts,
+    NestedKey,
+    TensorClass,
+    TensorDict,
+    TensorDictBase,
+)
 from torchrl.data.map.tdstorage import TensorDictMap
+from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree
+from torchrl.data.replay_buffers.storages import ListStorage
 from torchrl.envs.common import EnvBase
 
 
-@tensorclass
-class MCTSNode:
-    """An MCTS node.
+class Tree(TensorClass["nocast"]):
+    """Representation of a single MCTS (Monte Carlo Tree Search) Tree.
+
+    This class encapsulates the data and behavior of a tree node in an MCTS algorithm.
+    It includes attributes for storing information about the node, such as its children,
+    visit count, and rollout data. Methods are provided for traversing the tree,
+    computing statistics, and visualizing the tree structure.
+
+    A ``Tree`` is largely interchangeable with a node or a vertex: we use the term "tree" when
+    talking about a node with children, and "node" or "vertex" when talking about a place in the
+    tree where a branching occurs.
+    A node in the tree is defined primarily by its ``hash`` value. Usually, a ``hash`` is determined by a unique
+    combination of state (or observation) and action. If one observation (found in the ``node`` attribute) has more than
+    one action associated, each branch will be stored in the ``subtree`` attribute as a stack of ``Tree`` instances.
+
+    Attributes:
+        count (int): The number of visits to this node.
+        index (torch.Tensor): Indices of the child nodes in the data map.
+        hash (int): A hash value for this node.
+            It may be the case that ``hash`` is ``None`` in the specific case where the root of the tree
+            has more than one action associated. In that case, each subtree branch will have a different action
+            associated and a hash corresponding to the ``(observation, action)`` pair.
+        node_id (int): A unique identifier for this node.
+        rollout (TensorDict): Rollout data following the observation encoded in this node, in a TED format.
+            If there are multiple actions taken at this node, subtrees are stored in the corresponding
+            entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method.
+        node (TensorDict): Data defining this node (e.g., observations) before the next branching.
+            Entries usually match the ``in_keys`` in ``MCTSForest.node_map``.
+        subtree (Tree): A stack of subtrees produced when actions are taken.
+        num_children (int): The number of child nodes (read-only).
+        is_terminal (bool): whether the tree has no child nodes (read-only).
+            If the tree is compact, ``is_terminal == False`` implies that ``self.subtree`` contains
+            more than one child node.
+
+    Methods:
+        __contains__: Whether another tree can be found in the tree.
+        vertices: Returns a dictionary containing all vertices in the tree. Keys can be paths, ids or hashes.
+        num_vertices: Returns the total number of vertices in the tree, with or without duplicates.
+        edges: Returns a list of edges in the tree.
+        valid_paths: Yields all valid paths in the tree.
+        max_length: Returns the maximum length of any path in the tree.
+        rollout_from_path: Reconstructs a rollout from a given path.
+        plot: Visualizes the tree using a specified backend and figure type.
+        get_vertex_by_id: returns the vertex given by its id in the tree.
+        get_vertex_by_hash: returns the vertex given by its hash in the forest.
+
+    """
+
+    count: int = None
+    index: torch.Tensor | None = None
+    # The hash is None if the node has more than one action associated
+    hash: int | None = None
+    node_id: int | None = None
 
-    The batch-size of a root node is indicative of the batch-size of the tree:
-    each indexed element of a ``Node`` corresponds to a separate tree.
+    # rollout following the observation encoded in node, in a TorchRL (TED) format
+    rollout: TensorDict | None = None
 
-    A node is characterized by its data (a tensordict with keys such as ``"observation"``,
-    or ``"done"``), a ``children`` field containing all the branches from that node
-    (one per action taken), and a ``count`` tensor indicating how many times this node
-    has been visited.
+    # The data specifying the node
+    node: TensorDict | None = None
 
-    """
+    # Stack of subtrees. A subtree is produced when an action is taken.
+    subtree: "Tree" = None
 
-    data_content: TensorDict
-    children: MCTSChildren | None = None
-    count: torch.Tensor | None = None
-    forest: MCTSForest | None = None
+    @property
+    def num_children(self) -> int:
+        """Number of children of this node.
+
+        Equates to the number of elements in the ``self.subtree`` stack.
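+
+        A sketch of the invariant (assuming ``tree`` is a built ``Tree``):
+
+            >>> expected = 0 if tree.subtree is None else len(tree.subtree)
+            >>> assert tree.num_children == expected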
+ """ + return len(self.subtree) if self.subtree is not None else 0 + + @property + def is_terminal(self): + """Returns True if the the tree has no children nodes.""" + return self.subtree is None + + def get_vertex_by_id(self, id: int) -> Tree: + """Goes through the tree and returns the node corresponding the given id.""" + q = deque() + q.append(self) + while len(q): + tree = q.popleft() + if tree.node_id == id: + return tree + if tree.subtree is not None: + q.extend(tree.subtree.unbind(0)) + raise ValueError(f"Node with id {id} not found.") + + def get_vertex_by_hash(self, hash: int) -> Tree: + """Goes through the tree and returns the node corresponding the given hash.""" + q = deque() + q.append(self) + while len(q): + tree = q.popleft() + if tree.hash == hash: + return tree + if tree.subtree is not None: + q.extend(tree.subtree.unbind(0)) + raise ValueError(f"Node with hash {hash} not found.") + + def __contains__(self, other: Tree) -> bool: + hash = other.hash + for vertex in self.vertices().values(): + if vertex.hash == hash: + return True + else: + return False - def plot(self, backend="plotly", info=None): - forest = self.forest - if isinstance(forest, weakref.ref): - forest = forest() - return forest.plot(self, backend=backend, info=info) + def vertices( + self, *, key_type: Literal["id", "hash", "path"] = "hash" + ) -> Dict[int | Tuple[int], Tree]: + """Returns a map containing the vertices of the Tree. -@tensorclass -class MCTSChildren: - """The children of a node. + Keyword args: + key_type (Literal["id", "hash", "path"], optional): Specifies the type of key to use for the vertices. - This class contains data of the same batch-size: the ``index`` and ``hash`` - associated with each ``node``. Therefore, each indexed element of a ``Children`` - corresponds to one child with its associated index. Actions . + - "id": Use the vertex ID as the key. + - "hash": Use a hash of the vertex as the key. + - "path": Use the path to the vertex as the key. This may lead to a dictionary with a longer length than + when ``"id"`` or ``"hash"`` are used as the same node may be part of multiple trajectories. + Defaults to ``"hash"``. - """ + Defaults to an empty string, which may imply a default behavior. - node: MCTSNode - index: torch.Tensor | None = None - hash: torch.Tensor | None = None + Returns: + Dict[int | Tuple[int], Tree]: A dictionary mapping keys to Tree vertices. + + """ + memo = set() + result = {} + q = deque() + cur_path = () + q.append((self, cur_path)) + use_hash = key_type == "hash" + use_id = key_type == "id" + use_path = key_type == "path" + while len(q): + tree, cur_path = q.popleft() + h = tree.hash + if h in memo and not use_path: + continue + memo.add(h) + r = tree.rollout + if r is not None: + r = r["next", "observation"] + if use_path: + result[cur_path] = tree + elif use_id: + result[tree.node_id] = tree + elif use_hash: + result[tree.node_id] = tree + else: + raise ValueError( + f"key_type must be either 'hash', 'id' or 'path'. Got {key_type}." + ) + + n = int(tree.num_children) + for i in range(n): + cur_path_tree = cur_path + (i,) + q.append((tree.subtree[i], cur_path_tree)) + return result + + def num_vertices(self, *, count_repeat: bool = False) -> int: + """Returns the number of unique vertices in the Tree. + + Keyword Args: + count_repeat (bool, optional): Determines whether to count repeated vertices. + - If ``False``, counts each unique vertex only once. + - If ``True``, counts vertices multiple times if they appear in different paths. + Defaults to ``False``. 
+
+        Returns:
+            int: The number of unique vertices in the Tree.
+
+        """
+        return len(
+            {
+                v.node_id
+                for v in self.vertices(
+                    key_type="hash" if not count_repeat else "path"
+                ).values()
+            }
+        )
+
+    def edges(self) -> List[Tuple[int, int]]:
+        result = []
+        q = deque()
+        parent = self.node_id
+        q.append((self, parent))
+        while len(q):
+            tree, parent = q.popleft()
+            n = int(tree.num_children)
+            for i in range(n):
+                node = tree.subtree[i]
+                node_id = node.node_id
+                result.append((parent, node_id))
+                q.append((node, node_id))
+        return result
+
+    def valid_paths(self):
+        q = deque()
+        cur_path = ()
+        q.append((self, cur_path))
+        while len(q):
+            tree, cur_path = q.popleft()
+            n = int(tree.num_children)
+            if not n:
+                yield cur_path
+            for i in range(n):
+                cur_path_tree = cur_path + (i,)
+                q.append((tree.subtree[i], cur_path_tree))
+
+    def max_length(self):
+        # NOTE: `max(*gen)` would fail when there is a single valid path, so
+        # pass the generator directly.
+        return max(len(path) for path in self.valid_paths())
+
+    def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None:
+        r = self.rollout
+        tree = self
+        rollouts = []
+        if r is not None:
+            rollouts.append(r)
+        for i in path:
+            tree = tree.subtree[i]
+            r = tree.rollout
+            if r is not None:
+                rollouts.append(r)
+        if rollouts:
+            return torch.cat(rollouts, dim=-1)
+
+    @staticmethod
+    def _label(info: List[str], tree: "Tree", root=False):
+        labels = []
+        for key in info:
+            if key == "hash":
+                hash = tree.hash
+                if hash is not None:
+                    hash = hash.item()
+                v = f"hash={hash}"
+            elif root:
+                v = f"{key}=None"
+            else:
+                v = f"{key}={tree.rollout[key].mean().item()}"
+
+            labels.append(v)
+        return ", ".join(labels)
+
+    def plot(
+        self: Tree,
+        backend: str = "plotly",
+        figure: str = "tree",
+        info: List[str] = None,
+        make_labels: Callable[[Any], Any] | None = None,
+    ):
+        if backend == "plotly":
+            if figure == "box":
+                _plot_plotly_box(self)
+                return
+            elif figure == "tree":
+                _plot_plotly_tree(self, make_labels=make_labels)
+                return
+            else:
+                pass
+        raise NotImplementedError(
+            f"Unknown plotting backend {backend} with figure {figure}."
+        )
 
 
 class MCTSForest:
@@ -77,6 +311,8 @@ class MCTSForest:
         observation_keys (list of NestedKey): the observation keys of the environment.
             If not provided, defaults to ``("observation",)``.
             The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
+        consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk.
+            Defaults to ``False``.
 
     Examples:
         >>> from torchrl.envs import GymEnv
@@ -110,49 +346,58 @@ class MCTSForest:
         >>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
         >>> # Append the rollout to the forest. We're removing the state entries for clarity
         >>> rollout0 = rollout0.copy()
-        >>> rollout0.exclude(*state_keys, inplace=True)
-        >>> rollout0.get("next").exclude(*state_keys, inplace=True)
+        >>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
        >>> forest.extend(rollout0)
         >>> # The forest should have 6 elements (the length of the rollout)
         >>> assert len(forest) == 6
         >>> # Let's make another rollout from the same reset state
         >>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
-        >>> rollout1.exclude(*state_keys, inplace=True)
-        >>> rollout1.get("next").exclude(*state_keys, inplace=True)
+        >>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
         >>> forest.extend(rollout1)
         >>> assert len(forest) == 12
-        >>> # Since we have 2 rollouts starting at the same state, our tree should have two
-        >>> # branches if we produce it from the reset entry. Take the sate, and call `get_tree`:
-        >>> r = rollout0[0]
-        >>> tree = forest.get_tree(r)
-        >>> tree.shape
-        (torch.Size([]), _StringKeys(dict_keys(['data_content', 'children', 'count'])))
-        >>> tree.count
-        tensor(2, dtype=torch.int32)
-        >>> tree.children.shape
-        (torch.Size([2]),
-         _StringKeys(dict_keys(['node', 'action', 'reward', 'index', 'hash'])))
-        >>> # Now, let's run a rollout from an intermediate step from the second rollout
-        >>> starttd = rollout1[2]
-        >>> rollout2 = env.rollout(6, auto_reset=False, tensordict=starttd)
-        >>> rollout2.exclude(*state_keys, inplace=True)
-        >>> rollout2.get("next").exclude(*state_keys, inplace=True)
-        >>> forest.extend(rollout2)
+        >>> # Let's make another final rollout from an intermediate step in the second rollout
+        >>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
+        >>> rollout1b.exclude(*state_keys, inplace=True)
+        >>> rollout1b.get("next").exclude(*state_keys, inplace=True)
+        >>> forest.extend(rollout1b)
         >>> assert len(forest) == 18
-        >>> # Now our tree is a bit more complex, since from the 3rd level we have 3 different branches
-        >>> # When looking at the third level, we should see that the count for the left branch is 1, whereas the count
-        >>> # for the right branch is 2.
+        >>> # Since we have 2 rollouts starting at the same state, our tree should have two
+        >>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`:
         >>> r = rollout0[0]
+        >>> # Let's get the compact tree that follows the initial reset. A compact tree is
+        >>> # a tree where nodes that have a single child are collapsed.
         >>> tree = forest.get_tree(r)
-        >>> print(tree.children.node.children.node.count, tree.children.node.children.node.shape)
-        tensor([[1],
-                [2]], dtype=torch.int32) torch.Size([2, 1])
-        >>> # This is reflected by the shape of the children: the left has shape [1, 1], the right [1, 2]
-        >>> # and tensordict represents this as shape [2, 1, -1]
-        >>> tree.children.node.children.node.children.shape, tree.children.node.children.node.children[0].shape, tree.children.node.children.node.children[1].shape
-        (torch.Size([2, 1, -1]), torch.Size([1, 1]), torch.Size([1, 2]))
-        >>> tree.plot(info=["step_count", "reward"])
-
+        >>> print(tree.max_length())
+        2
+        >>> print(list(tree.valid_paths()))
+        [(0,), (1, 0), (1, 1)]
+        >>> from tensordict import assert_close
+        >>> # We can manually rebuild the tree
+        >>> assert_close(
+        ...     rollout1,
+        ...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
+        ...     intersection=True,
+        ... )
+        True
+        >>> # Or we can rebuild it using the dedicated method
+        >>> assert_close(
+        ...     rollout1,
+        ...     tree.rollout_from_path((1, 0)),
+        ...     intersection=True,
+        ... )
+        True
+        >>> tree.plot()
+        >>> tree = forest.get_tree(r, compact=False)
+        >>> print(tree.max_length())
+        9
+        >>> print(list(tree.valid_paths()))
+        [(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
+        >>> assert_close(
+        ...     rollout1,
+        ...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
+        ...     intersection=True,
+        ... )
+        True
     """
 
     def __init__(
@@ -164,6 +409,7 @@ def __init__(
         reward_keys: List[NestedKey] = None,
         observation_keys: List[NestedKey] = None,
         action_keys: List[NestedKey] = None,
+        consolidated: bool | None = None,
     ):
         self.data_map = data_map
@@ -174,6 +420,7 @@ def __init__(
         self.action_keys = action_keys
         self.reward_keys = reward_keys
         self.observation_keys = observation_keys
+        self.consolidated = consolidated
 
     @property
     def done_keys(self):
@@ -236,7 +483,7 @@ def get_keys_from_env(self, env: EnvBase):
     @classmethod
     def _write_fn_stack(cls, new, old=None):
         if old is None:
-            result = new.apply(lambda x: x.unsqueeze(0))
+            result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False)
             result.set(
                 "count", torch.ones(result.shape, dtype=torch.int, device=result.device)
             )
@@ -255,11 +502,17 @@ def cat(name, x, y):
         return result
 
     def _make_storage(self, source, dest):
-        self.data_map = TensorDictMap.from_tensordict_pair(
-            source,
-            dest,
-            in_keys=[*self.observation_keys, *self.action_keys],
-        )
+        try:
+            self.data_map = TensorDictMap.from_tensordict_pair(
+                source,
+                dest,
+                in_keys=[*self.observation_keys, *self.action_keys],
+                consolidated=self.consolidated,
+            )
+        except KeyError as err:
+            raise KeyError(
+                "A KeyError occurred during data map creation. This could be due to the wrong setting of a key in the MCTSForest constructor. Scroll up for more info."
+ ) from err def _make_storage_branches(self, source, dest): self.node_map = TensorDictMap.from_tensordict_pair( @@ -267,7 +520,7 @@ def _make_storage_branches(self, source, dest): dest, in_keys=[*self.observation_keys], out_keys=[ - *self.data_map.query_module.out_keys, # hash and index + *self.data_map.query_module.out_keys, # hash and index # *self.action_keys, # *[("next", rk) for rk in self.reward_keys], "count", @@ -278,7 +531,10 @@ def _make_storage_branches(self, source, dest): ) def extend(self, rollout): - source, dest = rollout.exclude("next").copy(), rollout.select("next", *self.action_keys).copy() + source, dest = ( + rollout.exclude("next").copy(), + rollout.select("next", *self.action_keys).copy(), + ) if self.data_map is None: self._make_storage(source, dest) @@ -293,311 +549,122 @@ def extend(self, rollout): self._make_storage_branches(source, dest) self.node_map[source] = TensorDict.lazy_stack(value.unbind(0)) - def get_child(self, root): + def get_child(self, root: TensorDictBase) -> TensorDictBase: return self.data_map[root] - def get_tree( + def _make_local_tree( self, - root, - *, - inplace: bool = False, - recurse: bool = True, - max_depth: int | None = None, - as_tensordict: bool = False, - compact: bool=False, - ): - if root.batch_size: - if compact: - raise NotImplementedError - func = self._get_tree_batched + root: TensorDictBase, + index: torch.Tensor | None = None, + compact: bool = True, + ) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]: + root = root.select(*self.node_map.in_keys) + node_meta = None + if root in self.node_map: + node_meta = self.node_map[root] + if index is None: + node_meta = self.node_map[root] + index = node_meta["_index"] + elif index is not None: + pass else: - if compact: - func = self._get_tree_single_compact + return None + steps = [] + while index.numel() <= 1: + index = index.squeeze() + d = self.data_map.storage[index] + steps.append(merge_tensordicts(d, root, callback_exist=lambda *x: None)) + d = d["next"] + if d in self.node_map: + root = d.select(*self.node_map.in_keys) + node_meta = self.node_map[root] + index = node_meta["_index"] + if not compact: + break else: - func = self._get_tree_single - return func( - root=root, - inplace=inplace, - recurse=recurse, - max_depth=max_depth, - as_tensordict=as_tensordict, + index = None + break + rollout = None + if steps: + rollout = torch.stack(steps, -1) + # Will be populated later + hash = node_meta["_hash"] + return ( + Tree( + rollout=rollout, + count=node_meta["count"], + node=root, + index=index, + hash=None, + subtree=None, + ), + index, + hash, ) - def _get_tree_single( - self, - root, - inplace: bool = False, - recurse: bool = True, - max_depth: int | None = None, - as_tensordict: bool = False, + # The recursive implementation is slower and less compatible with compile + # def _make_tree(self, root: TensorDictBase, index: torch.Tensor|None=None)->Tree: + # tree, indices = self._make_local_tree(root, index=index) + # subtrees = [] + # if indices is not None: + # for i in indices: + # subtree = self._make_tree(tree.node, index=i) + # subtrees.append(subtree) + # subtrees = TensorDict.lazy_stack(subtrees) + # tree.subtree = subtrees + # return tree + def _make_tree_iter( + self, root, index=None, max_depth: int | None = None, compact: bool = True ): - if root not in self.node_map: - if as_tensordict: - return TensorDict({"data_content": root}) - return MCTSNode(root, forest=weakref.ref(self)) - - branches = self.node_map[root] - - index = branches["_index"] - hash_val 
= branches["_hash"] - count = branches["count"] - - children_node = self.data_map.storage[index] - if not inplace: - root = root.copy() - if recurse: - children_node = children_node.unbind(0) - children_node = tuple( - self.get_tree( - child, - inplace=inplace, - max_depth=max_depth - 1 if isinstance(max_depth, int) else None, - ) - for child in children_node - ) - if not as_tensordict: - children_node = LazyStackedTensorDict( - *(child._tensordict for child in children_node) - ) - children_node = MCTSNode.from_tensordict(children_node) - children_node.forest = weakref.ref(self) - else: - children_node = LazyStackedTensorDict(*children_node) - - if not as_tensordict: - return MCTSNode( - data_content=root, - children=MCTSChildren( - node=children_node, - index=index, - hash=hash_val, - batch_size=children_node.batch_size, - ), - count=count, - forest=weakref.ref(self), - ) - return TensorDict( - { - "data_content": root, - "children": TensorDict( - { - "node": children_node, - "index": index, - "hash": hash_val, - }, - batch_sizde=children_node.batch_size, - ), - "count": count, - } - ) + q = deque() + memo = {} + tree, indices, hash = self._make_local_tree(root, index=index) + tree.node_id = 0 + + result = tree + depth = 0 + counter = 1 + if indices is not None: + q.append((tree, indices, hash, depth)) + del tree, indices + + while len(q): + tree, indices, hash, depth = q.popleft() + extend = max_depth is None or depth < max_depth + subtrees = [] + for i, h in zip(indices, hash): + # TODO: remove the .item() + h = h.item() + subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3) + if subtree is None: + subtree, subtree_indices, subtree_hash = self._make_local_tree( + tree.node, index=i, compact=compact + ) + subtree.node_id = counter + counter += 1 + subtree.hash = h + memo[h] = (subtree, subtree_indices, subtree_hash) + + subtrees.append(subtree) + if extend and subtree_indices is not None: + q.append((subtree, subtree_indices, subtree_hash, depth + 1)) + subtrees = TensorDict.lazy_stack(subtrees) + tree.subtree = subtrees - def _get_tree_single_compact( - self, - root, - inplace: bool = False, - recurse: bool = True, - max_depth: int | None = None, - as_tensordict: bool = False, - ): - if root not in self.node_map: - if as_tensordict: - return TensorDict({"data_content": root}) - return MCTSNode(root, forest=weakref.ref(self)) - - stack = [] - child = root.select(*self.observation_keys).copy() - while True: - if child not in self.node_map: - # we already know this doesn't happen during the first iter - break - branch = self.node_map[child] - index = branch["_index"] - if index.numel() == 1: - # child contains (action, next) so we can update the previous data with this - new_child = self.data_map.storage[index[0]] - child.update(new_child) - stack.append(child) - child = new_child.get("next").select(*self.observation_keys) - else: - break - if len(stack): - root = torch.stack(stack, -1) - elif not inplace: - root = child - - hash_val = branch["_hash"] - count = branch["count"] - - children_node = self.data_map.storage[index].get("next").select(*self.observation_keys) - if recurse: - children_node = children_node.unbind(0) - children_node = tuple( - self.get_tree( - child, - inplace=inplace, - max_depth=max_depth - 1 if isinstance(max_depth, int) else None, - compact=True - ) - for child in children_node - ) - if not as_tensordict: - children_node = LazyStackedTensorDict( - *(child._tensordict for child in children_node) - ) - children_node = 
MCTSNode.from_tensordict(children_node) - children_node.forest = weakref.ref(self) - else: - children_node = LazyStackedTensorDict(*children_node) - - if not as_tensordict: - return MCTSNode( - data_content=root, - children=MCTSChildren( - node=children_node, - index=index, - hash=hash_val, - batch_size=children_node.batch_size, - ), - count=count, - forest=weakref.ref(self), - ) - return TensorDict( - { - "data_content": root, - "children": TensorDict( - { - "node": children_node, - "index": index, - "hash": hash_val, - }, - batch_sizde=children_node.batch_size, - ), - "count": count, - } - ) + return result - def _get_tree_batched( + def get_tree( self, root, - inplace: bool = False, - recurse: bool = True, + *, max_depth: int | None = None, - as_tensordict: bool = False, - ): - present = self.node_map.contains(root) - if not present.any(): - if as_tensordict: - return TensorDict({"data_content": root}, batch_size=root.batch_size) - return MCTSNode(root, forest=weakref.ref(self), batch_size=root.batch_size) - if present.all(): - root_present = root - else: - root_present = root[present] - branches = self.node_map[root_present] - index = branches.get_nestedtensor("_index", layout=torch.jagged) - hash_val = branches.get_nestedtensor("_hash", layout=torch.jagged) - count = branches.get("count") - - children_node = self.data_map.storage[index.values()] - if not root_present.all(): - children_node = LazyStackedTensorDict( - *children_node.split(index.offsets().diff().tolist()) - ) - for idx in (~present).nonzero(as_tuple=True)[0].tolist(): - children_node.insert(idx, TensorDict()) # TODO: replace with new_zero - - if not inplace: - root = root.copy() - if recurse: - children_node = children_node.unbind(0) - children_node = tuple( - self.get_tree( - child, - inplace=inplace, - max_depth=max_depth - 1 if isinstance(max_depth, int) else None, - ) - if present[i] - else child - for i, child in enumerate(children_node) - ) - children = TensorDict.lazy_stack( - [ - TensorDict( - { - "node": _children_node, - "index": _index, - "hash": _hash_val, - }, - batch_size=_children_node.batch_size, - ) - for (_children_node, _action, _index, _hash_val, _reward) in zip( - children_node, - action.unbind(0), - index.unbind(0), - hash_val.unbind(0), - reward.unbind(0), - ) - ] - ) - if not as_tensordict: - return MCTSNode( - data_content=root, - children=MCTSChildren._from_tensordict(children), - count=count, - forest=weakref.ref(self), - batch_size=root.batch_size, - ) - return TensorDict( - { - "data_content": root, - "children": children, - "count": count, - }, - batch_size=root.batch_size, - ) + compact: bool = True, + ) -> Tree: + return self._make_tree_iter(root=root, max_depth=max_depth, compact=compact) + + @classmethod + def valid_paths(cls, tree: Tree): + yield from tree.valid_paths() def __len__(self): return len(self.data_map) - - @staticmethod - def _label(info, tree, root=False): - labels = [] - for key in info: - if key == "hash": - if not root: - v = f"hash={tree.hash.item()}" - else: - v = f"hash={tree.data_content['_hash'].item()}" - elif root: - v = f"{key}=None" - else: - v = f"{key}={tree.node.data_content[key].item()}" - - labels.append(v) - return ", ".join(labels) - - def plot(self, tree, backend="plotly", info=None): - if backend == "plotly": - import plotly.graph_objects as go - if info is None: - info = ["hash", "reward"] - - parents = [""] - labels = [self._label(info, tree, root=True)] - - _tree = tree - - def extend(tree, parent): - children = tree.children - if children is None: - 
-                    return
-                for child in children:
-                    labels.append(self._label(info, child))
-                    parents.append(parent)
-                    extend(child.node, labels[-1])
-
-            extend(_tree, labels[-1])
-            fig = go.Figure(go.Treemap(labels=labels, parents=parents))
-            fig.show()
-        else:
-            raise NotImplementedError(f"Unkown plotting backend {backend}.")
diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py
new file mode 100644
index 00000000000..9a54913ca2a
--- /dev/null
+++ b/torchrl/data/map/utils.py
@@ -0,0 +1,97 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from typing import Callable, List
+
+from tensordict import NestedKey
+
+
+def _plot_plotly_tree(
+    tree: "Tree", make_labels: Callable[[Tree], str] | None = None  # noqa: F821
+):
+    import plotly.graph_objects as go
+    from igraph import Graph
+
+    if make_labels is None:
+
+        def make_labels(tree):
+            return str((tree.node_id, tree.hash))
+
+    nr_vertices = tree.num_vertices()
+    vertices = tree.vertices()
+
+    v_label = [make_labels(subtree) for subtree in vertices.values()]
+    G = Graph(nr_vertices, tree.edges())
+
+    layout = G.layout_sugiyama(range(nr_vertices))
+
+    position = {k: layout[k] for k in range(nr_vertices)}
+
+    E = [e.tuple for e in G.es]  # list of edges
+
+    L = len(position)
+    Xn = [position[k][0] for k in range(L)]
+    Yn = [position[k][1] for k in range(L)]
+    Xe = []
+    Ye = []
+    for edge in E:
+        Xe += [position[edge[0]][0], position[edge[1]][0], None]
+        Ye += [position[edge[0]][1], position[edge[1]][1], None]
+
+    labels = v_label
+    fig = go.Figure()
+    fig.add_trace(
+        go.Scatter(
+            x=Xe,
+            y=Ye,
+            mode="lines",
+            line={"color": "rgb(210,210,210)", "width": 1},
+            hoverinfo="none",
+        )
+    )
+    fig.add_trace(
+        go.Scatter(
+            x=Xn,
+            y=Yn,
+            mode="markers",
+            name="nodes",
+            marker={
+                "symbol": "circle-dot",
+                "size": 18,
+                "color": "#6175c1",
+                "line": {"color": "rgb(50,50,50)", "width": 1},
+            },
+            text=labels,
+            hoverinfo="text",
+            opacity=0.8,
+        )
+    )
+    fig.show()
+
+
+def _plot_plotly_box(tree: "Tree", info: List[NestedKey] | None = None):  # noqa: F821
+    import plotly.graph_objects as go
+
+    if info is None:
+        info = ["hash", ("next", "reward")]
+
+    parents = [""]
+    labels = [tree._label(info, tree, root=True)]
+
+    _tree = tree
+
+    def extend(tree: "Tree", parent):  # noqa: F821
+        children = tree.subtree
+        if children is None:
+            return
+        for child in children:
+            labels.append(tree._label(info, child))
+            parents.append(parent)
+            extend(child, labels[-1])
+
+    extend(_tree, labels[-1])
+    fig = go.Figure(go.Treemap(labels=labels, parents=parents))
+    fig.show()
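Aside, a usage sketch rather than part of the patch: it assumes an already-populated MCTSForest named `forest` and a `root` TensorDict it knows about; `get_tree`, `valid_paths`, and the two plotting helpers are the ones defined above, and the plotly/igraph extras must be installed.

    # Hedged sketch: `forest` and `root` are assumed to exist.
    from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree

    tree = forest.get_tree(root, max_depth=None, compact=True)
    for path in forest.valid_paths(tree):
        print(path)  # iterate over root-to-leaf paths
    _plot_plotly_tree(tree)  # graph view, igraph Sugiyama layout
    _plot_plotly_box(tree)  # treemap view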
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 2672c90092f..5e7b80d7bed 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -19,6 +19,11 @@
 import torch

+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 from tensordict import (
     is_tensor_collection,
     is_tensorclass,
@@ -132,6 +137,9 @@ class ReplayBuffer:

            .. warning:: As of now, the generator has no effect on the transforms.
        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
            Defaults to ``False``.
+        compilable (bool, optional): whether the replay buffer is compilable.
+            If ``True``, the buffer cannot be shared between multiple processes.
+            Defaults to ``False``.

    Examples:
        >>> import torch
@@ -217,11 +225,20 @@ def __init__(
         checkpointer: "StorageCheckpointerBase" | None = None,  # noqa: F821
         generator: torch.Generator | None = None,
         shared: bool = False,
+        compilable: bool | None = None,
     ) -> None:
-        self._storage = storage if storage is not None else ListStorage(max_size=1_000)
+        self._storage = (
+            storage
+            if storage is not None
+            else ListStorage(max_size=1_000, compilable=compilable)
+        )
         self._storage.attach(self)
         self._sampler = sampler if sampler is not None else RandomSampler()
-        self._writer = writer if writer is not None else RoundRobinWriter()
+        self._writer = (
+            writer
+            if writer is not None
+            else RoundRobinWriter(compilable=bool(compilable))
+        )
         self._writer.register_storage(self._storage)

         self._get_collate_fn(collate_fn)
@@ -600,7 +617,9 @@ def _add(self, data):
         return index

     def _extend(self, data: Sequence) -> torch.Tensor:
-        with self._replay_lock, self._write_lock:
+        is_compiling = is_dynamo_compiling()
+        nc = contextlib.nullcontext()
+        with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
             if self.dim_extend > 0:
                 data = self._transpose(data)
             index = self._writer.extend(data)
@@ -653,7 +672,7 @@ def update_priority(

     @pin_memory_output
     def _sample(self, batch_size: int) -> Tuple[Any, dict]:
-        with self._replay_lock:
+        with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext():
             index, info = self._sampler.sample(self._storage, batch_size)
             info["index"] = index
             data = self._storage.get(index)
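The two guarded blocks above share one pattern: when Dynamo is tracing, entering a threading lock would graph-break, so a null context is substituted. The pattern in isolation (the list and function names are illustrative only, not part of this patch):

    import contextlib
    import threading

    try:
        from torch.compiler import is_dynamo_compiling
    except ImportError:
        from torch._dynamo import is_compiling as is_dynamo_compiling

    _lock = threading.Lock()
    _items = []

    def guarded_append(value):
        # Under torch.compile, skip the lock; eager execution keeps it.
        ctx = contextlib.nullcontext() if is_dynamo_compiling() else _lock
        with ctx:
            _items.append(value)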
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 21cbfce7b31..f029099eb7c 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -61,10 +61,15 @@ class Storage:
     _rng: torch.Generator | None = None

     def __init__(
-        self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
+        self,
+        max_size: int,
+        checkpointer: StorageCheckpointerBase | None = None,
+        compilable: bool = False,
     ) -> None:
         self.max_size = int(max_size)
         self.checkpointer = checkpointer
+        self._compilable = compilable
+        self._attached_entities_set = set()

     @property
     def checkpointer(self):
@@ -84,11 +89,14 @@ def _is_full(self):
     def _attached_entities(self):
         # RBs that use a given instance of Storage should add
         # themselves to this set.
-        _attached_entities = self.__dict__.get("_attached_entities_set", None)
-        if _attached_entities is None:
-            _attached_entities = set()
-            self.__dict__["_attached_entities_set"] = _attached_entities
-        return _attached_entities
+        _attached_entities_set = getattr(self, "_attached_entities_set", None)
+        if _attached_entities_set is None:
+            self._attached_entities_set = _attached_entities_set = set()
+        return _attached_entities_set
+
+    @torch._dynamo.assume_constant_result
+    def _attached_entities_iter(self):
+        return list(self._attached_entities)

     @abc.abstractmethod
     def set(self, cursor: int, data: Any, *, set_cursor: bool = True):
@@ -144,29 +152,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _empty(self):
         ...

-    # NOTE: This property is used to enable compiled Storages. Calling
-    # `len(self)` on a TensorStorage should normally cause a graph break since
-    # it uses a `mp.Value`, and it does cause a break when the `len(self)` call
-    # happens within a method of TensorStorage itself. However, when the
-    # `len(self)` call happens in the Storage base class, for an unknown reason
-    # the compiler doesn't seem to recognize that there should be a graph break,
-    # and the lack of a break causes a recompile each time `len(self)` is called
-    # in this context. Also for an unknown reason, we can force the graph break
-    # to happen if we wrap the `len(self)` call with a `property`-decorated
-    # function. For another unknown reason, if we change
-    # `TensorStorage._len_value` from `mp.Value` to int, it seems like there
-    # should no longer be any need to recompile, but recompiles happen anyway.
-    # Ideally, this should all be investigated and understood in the future.
-    @property
-    def len(self):
-        return len(self)
-
     def _rand_given_ndim(self, batch_size):
         # a method to return random indices given the storage ndim
         if self.ndim == 1:
             return torch.randint(
                 0,
-                self.len,
+                len(self),
                 (batch_size,),
                 generator=self._rng,
                 device=getattr(self, "device", None),
@@ -241,10 +232,10 @@ class ListStorage(Storage):

     _default_checkpointer = ListStorageCheckpointer

-    def __init__(self, max_size: int | None = None):
+    def __init__(self, max_size: int | None = None, compilable: bool = False):
         if max_size is None:
             max_size = torch.iinfo(torch.int64).max
-        super().__init__(max_size)
+        super().__init__(max_size, compilable=compilable)
         self._storage = []

     def set(
@@ -381,6 +372,9 @@ class TensorStorage(Storage):
            measuring the storage size. For instance, a storage of shape ``[3, 4]``
            has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
            Defaults to ``1``.
+        compilable (bool, optional): whether the storage is compilable.
+            If ``True``, the storage cannot be shared between multiple processes.
+            Defaults to ``False``.

    Examples:
        >>> data = TensorDict({
@@ -440,6 +434,7 @@ def __init__(
         *,
         device: torch.device = "cpu",
         ndim: int = 1,
+        compilable: bool = False,
     ):
         if not ((storage is None) ^ (max_size is None)):
             if storage is None:
         else:
             max_size = tree_flatten(storage)[0][0].shape[0]
         self.ndim = ndim
-        super().__init__(max_size)
+        super().__init__(max_size, compilable=compilable)
         self.initialized = storage is not None
         if self.initialized:
             self._len = max_size
@@ -474,16 +469,24 @@
     @property
     def _len(self):
         _len_value = self.__dict__.get("_len_value", None)
-        if _len_value is None:
-            _len_value = self._len_value = mp.Value("i", 0)
-        return _len_value.value
+        if not self._compilable:
+            if _len_value is None:
+                _len_value = self._len_value = mp.Value("i", 0)
+            return _len_value.value
+        else:
+            if _len_value is None:
+                _len_value = self._len_value = 0
+            return _len_value

     @_len.setter
     def _len(self, value):
-        _len_value = self.__dict__.get("_len_value", None)
-        if _len_value is None:
-            _len_value = self._len_value = mp.Value("i", 0)
-        _len_value.value = value
+        if not self._compilable:
+            _len_value = self.__dict__.get("_len_value", None)
+            if _len_value is None:
+                _len_value = self._len_value = mp.Value("i", 0)
+            _len_value.value = value
+        else:
+            self._len_value = value

     @property
     def _total_shape(self):
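The `_len` accessors above, like the writer's `_cursor` and `_write_count` further down, instantiate one idea: an `mp.Value` when the object may be shared across processes, a plain `int` when it is marked compilable so Dynamo can trace reads and writes without a graph break. A minimal standalone sketch of that idea (illustrative class, not part of the patch):

    import multiprocessing as mp

    class Counter:
        def __init__(self, compilable: bool = False):
            self._compilable = compilable
            # Plain int is traceable; mp.Value is process-shared but not.
            self._value = 0 if compilable else mp.Value("i", 0)

        @property
        def value(self) -> int:
            return self._value if self._compilable else self._value.value

        @value.setter
        def value(self, v: int) -> None:
            if self._compilable:
                self._value = v
            else:
                self._value.value = v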
@@ -550,7 +553,16 @@ def shape(self):
         if _total_shape is not None:
             return torch.Size([self._len_along_dim0] + list(_total_shape[1:]))

+    # TODO: Without this disable, compiler recompiles for back-to-back calls.
+    # Figuring out a way to avoid this disable would give better performance.
+    @torch._dynamo.disable()
     def _rand_given_ndim(self, batch_size):
+        return self._rand_given_ndim_impl(batch_size)
+
+    # At the moment, this is separated into its own function so that we can test
+    # it without the `torch._dynamo.disable` and detect if future updates to the
+    # compiler fix the recompile issue.
+    def _rand_given_ndim_impl(self, batch_size):
         if self.ndim == 1:
             return super()._rand_given_ndim(batch_size)
         shape = self.shape
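One way the separation above can be exercised: compile `_rand_given_ndim_impl` directly, without the `disable` wrapper, and watch for retraces. A sketch, assuming `storage` is an initialized multi-dimensional `TensorStorage`; `TORCH_LOGS` is the stock PyTorch logging switch, not something from this patch:

    import torch

    torch._dynamo.reset()
    rand_idx = torch.compile(storage._rand_given_ndim_impl)
    for _ in range(3):
        # Back-to-back calls; run with TORCH_LOGS="recompiles" to see
        # whether the compiler still retraces on each call.
        rand_idx(16)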
@@ -623,8 +635,11 @@ def assert_is_sharable(tensor):
     def __setstate__(self, state):
         len = state.pop("len__context", None)
         if len is not None:
-            _len_value = mp.Value("i", len)
-            state["_len_value"] = _len_value
+            if not state["_compilable"]:
+                _len_value = mp.Value("i", len)
+                state["_len_value"] = _len_value
+            else:
+                state["_len_value"] = len
         self.__dict__.update(state)

     def state_dict(self) -> Dict[str, Any]:
@@ -674,7 +689,7 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]

-    @implement_for("torch", "2.3")
+    @implement_for("torch", "2.3", compilable=True)
     def _set_tree_map(self, cursor, data, storage):
         def set_tensor(datum, store):
             store[cursor] = datum

         # this won't be available until v2.3
         tree_map(set_tensor, data, storage)
@@ -682,7 +697,7 @@ def set_tensor(datum, store):

-    @implement_for("torch", "2.0", "2.3")
+    @implement_for("torch", "2.0", "2.3", compilable=True)
     def _set_tree_map(self, cursor, data, storage):  # noqa: 534
         # flatten data and cursor
         data_flat = tree_flatten(data)[0]
@@ -700,7 +715,7 @@ def _get_new_len(self, data, cursor):
             numel = leaf.shape[:ndim].numel()
         self._len = min(self._len + numel, self.max_size)

-    @implement_for("torch", "2.0", None)
+    @implement_for("torch", "2.0", None, compilable=True)
     def set(
         self,
         cursor: Union[int, Sequence[int], slice],
@@ -742,7 +757,7 @@ def set(
         else:
             self._set_tree_map(cursor, data, self._storage)

-    @implement_for("torch", None, "2.0")
+    @implement_for("torch", None, "2.0", compilable=True)
     def set(  # noqa: F811
         self,
         cursor: Union[int, Sequence[int], slice],
@@ -893,6 +908,11 @@ class LazyTensorStorage(TensorStorage):
            measuring the storage size. For instance, a storage of shape ``[3, 4]``
            has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
            Defaults to ``1``.
+        compilable (bool, optional): whether the storage is compilable.
+            If ``True``, the storage cannot be shared between multiple processes.
+            Defaults to ``False``.
+        consolidated (bool, optional): if ``True``, the storage will be consolidated after
+            its first expansion. Defaults to ``False``.

    Examples:
        >>> data = TensorDict({
@@ -952,14 +972,26 @@ def __init__(
         *,
         device: torch.device = "cpu",
         ndim: int = 1,
+        compilable: bool = False,
+        consolidated: bool = False,
     ):
-        super().__init__(storage=None, max_size=max_size, device=device, ndim=ndim)
+        super().__init__(
+            storage=None,
+            max_size=max_size,
+            device=device,
+            ndim=ndim,
+            compilable=compilable,
+        )
+        self.consolidated = consolidated

     def _init(
         self,
         data: Union[TensorDictBase, torch.Tensor, "PyTree"],  # noqa: F821
     ) -> None:
-        torchrl_logger.debug("Creating a TensorStorage...")
+        if not self._compilable:
+            # TODO: Investigate why this seems to have a performance impact with
+            # the compiler
+            torchrl_logger.debug("Creating a TensorStorage...")
         if self.device == "auto":
             self.device = data.device
@@ -975,7 +1007,11 @@ def max_size_along_dim0(data_shape):

         if is_tensor_collection(data):
             out = data.to(self.device)
-            out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
+            out: TensorDictBase = torch.empty_like(
+                out.expand(max_size_along_dim0(data.shape))
+            )
+            if self.consolidated:
+                out = out.consolidate()
         else:
             # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
@@ -986,6 +1022,8 @@ def max_size_along_dim0(data_shape):
                 ),
                 data,
             )
+            if self.consolidated:
+                raise ValueError("Cannot consolidate non-tensordict storages.")

         self._storage = out
         self.initialized = True
@@ -1089,8 +1127,9 @@ def __init__(
         device: torch.device = "cpu",
         ndim: int = 1,
         existsok: bool = False,
+        compilable: bool = False,
     ):
-        super().__init__(max_size, ndim=ndim)
+        super().__init__(max_size, ndim=ndim, compilable=compilable)
         self.initialized = False
         self.scratch_dir = None
         self.existsok = existsok
@@ -1264,10 +1303,6 @@ def _rng(self, value):
         for storage in self._storages:
             storage._rng = value

-    @property
-    def _attached_entities(self):
-        return set()
-
     def extend(self, value):
         raise RuntimeError
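A short sketch of the new ``consolidated`` flag (sizes and keys are illustrative, only the kwarg itself comes from this patch): the storage is laid out on the first ``extend``, and the backing tensordict is consolidated into a single contiguous buffer at that point; non-tensordict data raises, as the hunk above shows.

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyTensorStorage(1_000, consolidated=True), batch_size=8)
    # First extend triggers _init and the consolidation.
    rb.extend(TensorDict({"obs": torch.randn(16, 4)}, batch_size=[16]))
    sample = rb.sample()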
+ + """ + + def __init__(self, compilable: bool = False) -> None: + super().__init__(compilable=compilable) self._cursor = 0 def dumps(self, path): @@ -197,7 +205,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -213,30 +221,46 @@ def _empty(self): @property def _cursor(self): _cursor_value = self.__dict__.get("_cursor_value", None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value("i", 0) - return _cursor_value.value + if not self._compilable: + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + return _cursor_value.value + else: + if _cursor_value is None: + _cursor_value = self._cursor_value = 0 + return _cursor_value @_cursor.setter def _cursor(self, value): - _cursor_value = self.__dict__.get("_cursor_value", None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value("i", 0) - _cursor_value.value = value + if not self._compilable: + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + _cursor_value.value = value + else: + self._cursor_value = value @property def _write_count(self): _write_count = self.__dict__.get("_write_count_value", None) - if _write_count is None: - _write_count = self._write_count_value = mp.Value("i", 0) - return _write_count.value + if not self._compilable: + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + return _write_count.value + else: + if _write_count is None: + _write_count = self._write_count_value = 0 + return _write_count @_write_count.setter def _write_count(self, value): - _write_count = self.__dict__.get("_write_count_value", None) - if _write_count is None: - _write_count = self._write_count_value = mp.Value("i", 0) - _write_count.value = value + if not self._compilable: + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + _write_count.value = value + else: + self._write_count_value = value def __getstate__(self): state = super().__getstate__() @@ -249,7 +273,10 @@ def __getstate__(self): def __setstate__(self, state): cursor = state.pop("cursor__context", None) if cursor is not None: - _cursor_value = mp.Value("i", cursor) + if not state["_compilable"]: + _cursor_value = mp.Value("i", cursor) + else: + _cursor_value = cursor state["_cursor_value"] = _cursor_value self.__dict__.update(state) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 600e03775f7..83233aaaac0 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4596,10 +4596,11 @@ class TensorDictPrimer(Transform): The corresponding value has to be a TensorSpec instance indicating what the value must be. - When used in a TransfomedEnv, the spec shapes must match the envs shape if - the parent env is batch-locked (:obj:`env.batch_locked=True`). - If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is - given by the input tensordict instead. 
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 600e03775f7..83233aaaac0 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4596,10 +4596,11 @@ class TensorDictPrimer(Transform):
     The corresponding value has to be a TensorSpec instance indicating
     what the value must be.

-    When used in a TransfomedEnv, the spec shapes must match the envs shape if
-    the parent env is batch-locked (:obj:`env.batch_locked=True`).
-    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
-    given by the input tensordict instead.
+    When used in a ``TransformedEnv``, the spec shapes must match the environment's shape if
+    the parent environment is batch-locked (``env.batch_locked=True``). If the spec shapes and
+    parent shapes do not match, the spec shapes are modified in-place to match the leading
+    dimensions of the parent's batch size. This adjustment is made for cases where the parent
+    batch size dimension is not known during instantiation.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4639,6 +4640,40 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])

+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from torchrl.envs import SerialEnv, TransformedEnv
+        >>> from torchrl.modules.utils import get_primers_from_module
+        >>> from torchrl.modules import GRUModule
+        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
+        >>> env = TransformedEnv(base_env)
+        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
+        >>> primers = get_primers_from_module(model)
+        >>> print(primers) # Primers shape is independent of the env batch size
+        TensorDictPrimer(primers=Composite(
+            recurrent_state: UnboundedContinuous(
+                shape=torch.Size([1, 2]),
+                space=ContinuousBox(
+                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
+                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
+                device=cpu,
+                dtype=torch.float32,
+                domain=continuous),
+            device=None,
+            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
+        >>> env.append_transform(primers)
+        >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
+        TensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+
     .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
         like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4764,7 +4799,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
                 # We try to set the primer shape to the observation spec shape
                 self.primers.shape = observation_spec.shape
             except ValueError:
-                # If we fail, we expnad them to that shape
+                # If we fail, we expand them to that shape
                 self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
@@ -4831,12 +4866,17 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.

-        If the parent is batch-locked, we assume that the specs have the appropriate leading
+        If the parent is batch-locked, we make sure the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the
         spec shape is assumed to match the tensordict's.
- """ _reset = _get_reset(self.reset_key, tensordict) + if ( + self.parent + and self.parent.batch_locked + and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size + ): + self.primers = self._expand_shape(self.primers) if _reset.any(): for key, spec in self.primers.items(True, True): if self.random: diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index efc951b3999..ef78bc4bb0f 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -463,6 +463,8 @@ def reset(self) -> None: def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: entropy = dist.entropy() + if is_tensor_collection(entropy): + entropy = entropy.get(dist.entropy_key) except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x)