diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py
index b7af381634c..25d1e7a4390 100644
--- a/.github/unittest/linux_sota/scripts/test_sota.py
+++ b/.github/unittest/linux_sota/scripts/test_sota.py
@@ -188,19 +188,6 @@ ppo.collector.frames_per_batch=16 \
 logger.mode=offline \
 logger.backend=
-""",
- "dreamer": """python sota-implementations/dreamer/dreamer.py \
- collector.total_frames=600 \
- collector.init_random_frames=10 \
- collector.frames_per_batch=200 \
- env.n_parallel_envs=1 \
- optimization.optim_steps_per_batch=1 \
- logger.video=False \
- logger.backend=csv \
- replay_buffer.buffer_size=120 \
- replay_buffer.batch_size=24 \
- replay_buffer.batch_length=12 \
- networks.rssm_hidden_dim=17
 """,
 "ddpg-single": """python sota-implementations/ddpg/ddpg.py \
 collector.total_frames=48 \
@@ -289,6 +276,19 @@ logger.backend=
 """,
 "bandits": """python sota-implementations/bandits/dqn.py --n_steps=100
+""",
+ "dreamer": """python sota-implementations/dreamer/dreamer.py \
+ collector.total_frames=600 \
+ collector.init_random_frames=10 \
+ collector.frames_per_batch=200 \
+ env.n_parallel_envs=1 \
+ optimization.optim_steps_per_batch=1 \
+ logger.video=False \
+ logger.backend=csv \
+ replay_buffer.buffer_size=120 \
+ replay_buffer.batch_size=24 \
+ replay_buffer.batch_length=12 \
+ networks.rssm_hidden_dim=17
 """,
 }
diff --git a/examples/agents/composite_actor.py b/examples/agents/composite_actor.py
index ae08062e084..c7e83095983 100644
--- a/examples/agents/composite_actor.py
+++ b/examples/agents/composite_actor.py
@@ -50,3 +50,9 @@ def forward(self, x):
 data = TensorDict({"x": torch.rand(10)}, [])
 module(data)
 print(actor(data))
+
+
+# TODO:
+# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
+# 2. Must multi-head require an action_key to be a list of keys (I guess so)
+# 3. Using maps in the Actor
diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py
new file mode 100644
index 00000000000..501dceb651d
--- /dev/null
+++ b/examples/agents/composite_ppo.py
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Multi-head Agent and PPO Loss
+=============================
+This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
+(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.
+
+Step-by-step Explanation
+------------------------
+
+1. **Setting Composite Log-Probabilities**:
+ - To use composite (=multi-head) distributions with PPO (or any other algorithm that relies on probability distributions like SAC
+ or A2C), you must call `set_composite_lp_aggregate(False).set()`. Omitting this call will result in errors during
+ the execution of your script.
+ - From torchrl and tensordict v0.9, this will be the default behavior. Until then, not doing so will result in
+ `CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
+ - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be named
+ `<sample_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head policies
+ for instance, the log-probability will be named `"action_log_prob"`.
+ Previously, log-prob keys defaulted to `sample_log_prob`.
+2. 
**Action Grouping**: + - Actions can be grouped or not; PPO doesn't require them to be grouped. + - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and + log-probability, e.g., `agent0`, `agent0_log_prob`, etc. + + ... [...] + ... action: TensorDict( + ... fields={ + ... agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + ... agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + ... agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False), + ... agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False), + ... agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + ... agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, + ... batch_size=torch.Size([4]), + ... device=None, + ... is_shared=False), + + - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields. + + ... [...] + ... agent0: TensorDict( + ... fields={ + ... action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + ... action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, + ... batch_size=torch.Size([4]), + ... device=None, + ... is_shared=False), + ... agent1: TensorDict( + ... fields={ + ... action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False), + ... action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)}, + ... batch_size=torch.Size([4]), + ... device=None, + ... is_shared=False), + ... agent2: TensorDict( + ... fields={ + ... action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + ... action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, + ... batch_size=torch.Size([4]), + ... device=None, + ... is_shared=False), + +3. **PPO Loss Calculation**: + - Under the hood, `ClipPPO` will clip individual weights (not the aggregate) and multiply that by the advantage. + +The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses. 
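+
+A note on loss keys: in the script below, the PPO losses read the action and log-prob keys directly from the policy,
+so no manual configuration is needed. Should your keys differ from the ones the actor advertises, they can be remapped
+explicitly with `loss.set_keys`. A minimal sketch using the ungrouped key layout of this example (the names shown are
+illustrative):
+
+    >>> loss = ClipPPOLoss(policy, value_operator)
+    >>> loss.set_keys(
+    ...     action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
+    ...     sample_log_prob=[
+    ...         ("agent0", "action_log_prob"),
+    ...         ("agent1", "action_log_prob"),
+    ...         ("agent2", "action_log_prob"),
+    ...     ],
+    ... )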
+ +""" + +import functools + +import torch +from tensordict import TensorDict +from tensordict.nn import ( + CompositeDistribution, + InteractionType, + ProbabilisticTensorDictModule as Prob, + ProbabilisticTensorDictSequential as ProbSeq, + set_composite_lp_aggregate, + TensorDictModule as Mod, + TensorDictSequential as Seq, + WrapModule as Wrap, +) +from torch import distributions as d +from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss + +set_composite_lp_aggregate(False).set() + +GROUPED_ACTIONS = False + +make_params = Mod( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], +) + + +def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + +if GROUPED_ACTIONS: + name_map = { + "gamma": ("action", "agent0"), + "Kumaraswamy": ("action", "agent1"), + "mixture": ("action", "agent2"), + } +else: + name_map = { + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + } + +dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map=name_map, +) + + +policy = ProbSeq( + make_params, + Prob( + in_keys=["params"], + out_keys=list(name_map.values()), + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), +) + +td = policy(TensorDict(batch_size=[4])) +print("Result of policy call", td) + +dist = policy.get_dist(td) +log_prob = dist.log_prob(td) +print("Composite log-prob", log_prob) + +# Build a dummy value operator +value_operator = Seq( + Wrap( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) +) + +# Create fake data +data = policy(TensorDict(batch_size=[4])) +data.set( + "next", + TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)), +) + +# Instantiate the loss - test the 3 different PPO losses +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + # PPO sets the keys automatically by looking at the policy + ppo = loss_cls(policy, value_operator) + print("tensor keys", ppo.tensor_keys) + + # Get the loss values + loss_vals = ppo(data) + print("Loss result:", loss_cls, loss_vals) diff --git a/test/test_cost.py b/test/test_cost.py index a0283e0e276..c8e45624580 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -8,7 +8,6 @@ import itertools import operator import os - import sys import warnings from copy import deepcopy @@ -22,6 +21,7 @@ from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict._C import unravel_keys from tensordict.nn import ( + composite_lp_aggregate, CompositeDistribution, InteractionType, NormalParamExtractor, @@ -29,11 +29,14 @@ ProbabilisticTensorDictModule as ProbMod, ProbabilisticTensorDictSequential, ProbabilisticTensorDictSequential as ProbSeq, + set_composite_lp_aggregate, TensorDictModule, TensorDictModule as Mod, TensorDictSequential, TensorDictSequential as Seq, + WrapModule, ) +from 
tensordict.nn.distributions.composite import _add_suffix from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn @@ -197,6 +200,13 @@ def get_devices(): class LossModuleTestBase: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) assert hasattr( @@ -3562,7 +3572,6 @@ def _create_mock_actor( distribution_map={ "action1": TanhNormal, }, - aggregate_probabilities=True, ) module_out_keys = [ ("params", "action1", "loc"), @@ -3582,6 +3591,7 @@ def _create_mock_actor( out_keys=[action_key], spec=action_spec, ) + assert actor.log_prob_keys return actor.to(device) def _create_mock_qvalue( @@ -3687,7 +3697,6 @@ def forward(self, obs, act): distribution_map={ "action1": TanhNormal, }, - aggregate_probabilities=True, ) module_out_keys = [ ("params", "action1", "loc"), @@ -4341,7 +4350,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist): "value": "state_value", "state_action_value": "state_action_value", "action": "action", - "log_prob": "sample_log_prob", + "log_prob": "action_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", @@ -6771,7 +6780,7 @@ def test_redq_tensordict_keys(self, td_est): "priority": "td_error", "action": "action", "value": "state_value", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", "state_action_value": "state_action_value", "reward": "reward", "done": "done", @@ -6834,12 +6843,22 @@ def test_redq_notensordict( actor_network=actor, qvalue_network=qvalue, ) - loss.set_keys( - action=action_key, - reward=reward_key, - done=done_key, - terminated=terminated_key, - ) + if deprec: + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + log_prob=_add_suffix(action_key, "_log_prob"), + ) + else: + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + sample_log_prob=_add_suffix(action_key, "_log_prob"), + ) kwargs = { action_key: td.get(action_key), @@ -7907,30 +7926,31 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", - action_key="action", + action_key=None, observation_key="observation", - sample_log_prob_key="sample_log_prob", + sample_log_prob_key=None, composite_action_dist=False, - aggregate_probabilities=True, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - if composite_action_dist: - action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) if composite_action_dist: + if action_key is None: + action_key = ("action", "action1") + else: + action_key = (action_key, "action1") + action_spec = Composite({action_key: {"action1": action_spec}}) distribution_class = functools.partial( CompositeDistribution, distribution_map={ "action1": TanhNormal, }, name_map={ - "action1": (action_key, "action1"), + "action1": action_key, }, log_prob_key=sample_log_prob_key, - aggregate_probabilities=aggregate_probabilities, ) module_out_keys = [ ("params", "action1", "loc"), @@ -7938,6 +7958,8 @@ def _create_mock_actor( ] actor_in_keys = ["params"] else: + if action_key is None: + action_key = "action" distribution_class = TanhNormal module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( @@ -7978,7 
+8000,7 @@ def _create_mock_actor_value( action_dim=4, device="cpu", composite_action_dist=False, - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", ): # Actor action_spec = Bounded( @@ -8000,7 +8022,6 @@ def _create_mock_actor_value( "action1": ("action", "action1"), }, log_prob_key=sample_log_prob_key, - aggregate_probabilities=True, ) module_out_keys = [ ("params", "action1", "loc"), @@ -8034,7 +8055,7 @@ def _create_mock_actor_value_shared( action_dim=4, device="cpu", composite_action_dist=False, - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", ): # Actor action_spec = Bounded( @@ -8057,7 +8078,6 @@ def _create_mock_actor_value_shared( "action1": ("action", "action1"), }, log_prob_key=sample_log_prob_key, - aggregate_probabilities=True, ) module_out_keys = [ ("params", "action1", "loc"), @@ -8100,7 +8120,7 @@ def _create_mock_data_ppo( reward_key="reward", done_key="done", terminated_key="terminated", - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", composite_action_dist=False, ): # create a tensordict @@ -8148,8 +8168,8 @@ def _create_seq_mock_data_ppo( action_dim=4, atoms=None, device="cpu", - sample_log_prob_key="sample_log_prob", - action_key="action", + sample_log_prob_key=None, + action_key=None, composite_action_dist=False, ): # create a tensordict @@ -8171,6 +8191,18 @@ def _create_seq_mock_data_ppo( params_scale = torch.rand_like(action) / 10 loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) + if sample_log_prob_key is None: + if composite_action_dist: + sample_log_prob_key = ("action", "action1_log_prob") + else: + # conforming to composite_lp_aggregate(False) + sample_log_prob_key = "action_log_prob" + + if action_key is None: + if composite_action_dist: + action_key = ("action", "action1") + else: + action_key = "action" td = TensorDict( batch_size=(batch, T), source={ @@ -8182,7 +8214,7 @@ def _create_seq_mock_data_ppo( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: {"action1": action} if composite_action_dist else action, + action_key: action, sample_log_prob_key: ( torch.randn_like(action[..., 1]) / 10 ).masked_fill_(~mask, 0.0), @@ -8262,7 +8294,15 @@ def test_ppo( loss_critic_type="l2", functional=functional, ) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + if advantage is not None: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) if advantage is not None: + assert not composite_lp_aggregate() advantage(td) else: if td_est is not None: @@ -8322,7 +8362,6 @@ def test_ppo_composite_no_aggregate( actor = self._create_mock_actor( device=device, composite_action_dist=True, - aggregate_probabilities=False, ) value = self._create_mock_value(device=device) if advantage == "gae": @@ -8355,7 +8394,12 @@ def test_ppo_composite_no_aggregate( loss_critic_type="l2", functional=functional, ) + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) if advantage is not None: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) else: if td_est is not None: @@ -8463,7 +8507,15 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): ) if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", 
"action1_log_prob")]) advantage(td) + + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8570,7 +8622,20 @@ def test_ppo_shared_seq( ) if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + loss_fn2.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + loss = loss_fn(td).exclude("entropy") sum(val for key, val in loss.items() if key.startswith("loss_")).backward() @@ -8658,7 +8723,14 @@ def zero_param(p): # assert len(list(floss_fn.parameters())) == 0 with params.to_module(loss_fn): if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8709,6 +8781,7 @@ def zero_param(p): ) @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist): + assert not composite_lp_aggregate() actor = self._create_mock_actor(composite_action_dist=composite_action_dist) value = self._create_mock_value() @@ -8718,8 +8791,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist): "advantage": "advantage", "value_target": "value_target", "value": "state_value", - "sample_log_prob": "sample_log_prob", - "action": "action", + "sample_log_prob": "action_log_prob" + if not composite_action_dist + else ("action", "action1_log_prob"), + "action": "action" if not composite_action_dist else ("action", "action1"), "reward": "reward", "done": "done", "terminated": "terminated", @@ -8748,10 +8823,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - @pytest.mark.parametrize("composite_action_dist", [True, False]) - def test_ppo_tensordict_keys_run( - self, loss_class, advantage, td_est, composite_action_dist - ): + def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8759,18 +8831,16 @@ def test_ppo_tensordict_keys_run( "advantage": "advantage_test", "value_target": "value_target_test", "value": "state_value_test", - "sample_log_prob": "sample_log_prob_test", + "sample_log_prob": "action_log_prob_test", "action": "action_test", } td = self._create_seq_mock_data_ppo( sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], - composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( sample_log_prob_key=tensor_keys["sample_log_prob"], - composite_action_dist=composite_action_dist, action_key=tensor_keys["action"], ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) @@ -8864,9 +8934,7 @@ def test_ppo_tensordict_keys_run( @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) 
@pytest.mark.parametrize( "composite_action_dist", - [ - False, - ], + [False], ) def test_ppo_notensordict( self, @@ -8987,11 +9055,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist): reduction=reduction, ) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) if reduction == "none": for key in loss.keys(): if key.startswith("loss_"): - assert loss[key].shape == td.shape + assert loss[key].shape == td.shape, key else: for key in loss.keys(): if not key.startswith("loss_"): @@ -9039,6 +9112,11 @@ def test_ppo_value_clipping( clip_value=clip_value, ) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) value = td.pop(loss_fn.tensor_keys.value) @@ -9060,6 +9138,103 @@ def test_ppo_value_clipping( loss = loss_fn(td) assert "loss_critic" in loss.keys() + def test_ppo_composite_dists(self): + d = torch.distributions + + make_params = TensorDictModule( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], + ) + + def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + ) + policy = ProbSeq( + make_params, + ProbabilisticTensorDictModule( + in_keys=["params"], + out_keys=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), + ) + # We want to make sure there is no warning + td = policy(TensorDict(batch_size=[4])) + assert isinstance( + policy.get_dist(td).log_prob(td), + TensorDict, + ) + assert isinstance( + policy.log_prob(td), + TensorDict, + ) + value_operator = Seq( + WrapModule( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) + ) + for cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + data = policy(TensorDict(batch_size=[4])) + data.set( + "next", + TensorDict( + reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool) + ), + ) + ppo = cls(policy, value_operator) + ppo.set_keys( + action=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + loss = ppo(data) + loss.sum(reduce=True) + class TestA2C(LossModuleTestBase): seed = 0 @@ -9072,8 +9247,8 @@ def _create_mock_actor( device="cpu", action_key="action", observation_key="observation", - sample_log_prob_key="sample_log_prob", composite_action_dist=False, + sample_log_prob_key=None, ): # Actor action_spec = Bounded( @@ -9091,8 +9266,6 @@ def _create_mock_actor( name_map={ 
"action1": (action_key, "action1"), }, - log_prob_key=sample_log_prob_key, - aggregate_probabilities=True, ) module_out_keys = [ ("params", "action1", "loc"), @@ -9142,7 +9315,6 @@ def _create_mock_common_layer_setup( n_hidden=2, T=10, composite_action_dist=False, - sample_log_prob_key="sample_log_prob", ): common_net = MLP( num_cells=ncells, @@ -9168,7 +9340,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": {"action1": action} if composite_action_dist else action, - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -9192,8 +9364,6 @@ def _create_mock_common_layer_setup( name_map={ "action1": ("action", "action1"), }, - log_prob_key=sample_log_prob_key, - aggregate_probabilities=True, ) module_out_keys = [ ("params", "action1", "loc"), @@ -9234,7 +9404,7 @@ def _create_seq_mock_data_a2c( reward_key="reward", done_key="done", terminated_key="terminated", - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", composite_action_dist=False, ): # create a tensordict @@ -9366,6 +9536,11 @@ def set_requires_grad(tensor, requires_grad): td = td.exclude(loss_fn.tensor_keys.value_target) if advantage is not None: + advantage.set_keys( + sample_log_prob=actor.log_prob_keys + if composite_action_dist + else "action_log_prob" + ) advantage(td) elif td_est is not None: loss_fn.make_value_estimator(td_est) @@ -9585,7 +9760,7 @@ def test_a2c_tensordict_keys(self, td_est, composite_action_dist): "reward": "reward", "done": "done", "terminated": "terminated", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", } self.tensordict_keys_test( @@ -9629,7 +9804,7 @@ def test_a2c_tensordict_keys_run( value_key = "state_value_test" action_key = "action_test" reward_key = "reward_test" - sample_log_prob_key = "sample_log_prob_test" + sample_log_prob_key = "action_log_prob_test" done_key = ("done", "test") terminated_key = ("terminated", "test") @@ -10073,7 +10248,7 @@ def test_reinforce_tensordict_keys(self, td_est): "advantage": "advantage", "value_target": "value_target", "value": "state_value", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", @@ -10131,7 +10306,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -11603,7 +11778,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -12419,7 +12594,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -15043,6 +15218,7 @@ def test_successive_traj_gae( ["half", torch.half, "cpu"], ], ) +@set_composite_lp_aggregate(False) def test_shared_params(dest, 
expected_dtype, expected_device): if torch.cuda.device_count() == 0 and dest == "cuda": pytest.skip("no cuda device available") @@ -15147,6 +15323,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: class TestAdv: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + @pytest.mark.parametrize( "adv,kwargs", [ @@ -15184,7 +15367,7 @@ def test_dispatch( ) kwargs = { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next_reward": torch.randn(1, 10, 1, requires_grad=True), "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), @@ -15246,7 +15429,7 @@ def test_diff_reward( td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15319,7 +15502,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15390,7 +15573,7 @@ def test_time_dim(self, adv, kwargs, shifted=True): td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15491,7 +15674,7 @@ def test_skip_existing( td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "state_value": torch.ones(1, 10, 1), "next": { "obs": torch.randn(1, 10, 3), @@ -15629,6 +15812,13 @@ def test_set_deprecated_keys(self, adv, kwargs): class TestBase: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def test_decorators(self): class MyLoss(LossModule): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -15848,6 +16038,13 @@ class _AcceptedKeys: class TestUtils: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def test_standardization(self): t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6) std_t0 = _standardize(t, exclude_dims=(1, 3)) @@ -16030,6 +16227,7 @@ def fun(a, b, time_dim=-2): (SoftUpdate, {"eps": 0.99}), ], ) +@set_composite_lp_aggregate(False) def test_updater_warning(updater, kwarg): with warnings.catch_warnings(): dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot") @@ -16042,6 +16240,13 @@ def test_updater_warning(updater, kwarg): class TestSingleCall: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def _mock_value_net(self, has_target, value_key): model = nn.Linear(3, 1) module = TensorDictModule(model, in_keys=["obs"], out_keys=[value_key]) @@ -16094,6 +16299,7 @@ def test_single_call(self, has_target, value_key, single_call, 
detach_next=True) assert (value != value_).all() +@set_composite_lp_aggregate(False) def test_instantiate_with_different_keys(): loss_1 = DQNLoss( value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True @@ -16108,6 +16314,13 @@ def test_instantiate_with_different_keys(): class TestBuffer: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + # @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half)) # def test_param_cast(self, dtype): # param = nn.Parameter(torch.zeros(3)) @@ -16217,6 +16430,7 @@ def __init__(self): TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) @pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile") +@set_composite_lp_aggregate(False) def test_exploration_compile(): try: torch._dynamo.reset_code_caches() @@ -16283,6 +16497,7 @@ def func(t): assert it == exploration_type() +@set_composite_lp_aggregate(False) def test_loss_exploration(): class DummyLoss(LossModule): def forward(self, td, mode): diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 4eb6e702c31..392b7291df9 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import warnings from typing import Dict, List, Optional, Type, Union @@ -86,7 +87,7 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): distribution sample will be written in the tensordict with the key `'sample_log_prob'`. Default is ``False``. log_prob_key (NestedKey, optional): key where to write the log_prob if return_log_prob = True. - Defaults to `'sample_log_prob'`. + Defaults to `"action_log_prob"`. cache_dist (bool, optional): EXPERIMENTAL: if ``True``, the parameters of the distribution (i.e. the output of the module) will be written to the tensordict along with the sample. 
Those parameters can be used to re-compute @@ -108,7 +109,7 @@ def __init__( distribution_class: Type = Delta, distribution_kwargs: Optional[dict] = None, return_log_prob: bool = False, - log_prob_key: Optional[NestedKey] = "sample_log_prob", + log_prob_key: NestedKey | None = None, cache_dist: bool = False, n_empirical_estimate: int = 1000, ): @@ -140,7 +141,7 @@ def __init__( elif spec is None: spec = Composite() spec_keys = set(unravel_key_list(list(spec.keys(True, True)))) - out_keys = set(unravel_key_list(self.out_keys)) + out_keys = set(unravel_key_list(self._out_keys)) if spec_keys != out_keys: # then assume that all the non indicated specs are None for key in out_keys: diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index e90a188331c..7701c1a662f 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -16,7 +16,14 @@ TensorDictBase, TensorDictParams, ) -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + composite_lp_aggregate, + CompositeDistribution, + dispatch, + ProbabilisticTensorDictSequential, + set_composite_lp_aggregate, + TensorDictModule, +) from tensordict.utils import NestedKey from torch import distributions as d @@ -240,10 +247,17 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" - sample_log_prob: NestedKey = "sample_log_prob" + sample_log_prob: NestedKey | None = None + + def __post_init__(self): + if self.sample_log_prob is None: + if composite_lp_aggregate(nowarn=True): + self.sample_log_prob = "sample_log_prob" + else: + self.sample_log_prob = "action_log_prob" + default_keys = _AcceptedKeys tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE actor_network: TensorDictModule @@ -353,6 +367,13 @@ def __init__( else: self.clip_value = None + log_prob_keys = self.actor_network.log_prob_keys + action_keys = self.actor_network.dist_sample_keys + if len(log_prob_keys) > 1: + self.set_keys(sample_log_prob=log_prob_keys, action=action_keys) + else: + self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0]) + @property def functional(self): return self._functional @@ -401,6 +422,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def reset(self) -> None: pass + @set_composite_lp_aggregate(False) def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: if HAS_ENTROPY.get(type(dist), False): entropy = dist.entropy() @@ -408,36 +430,39 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) if is_tensor_collection(log_prob): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + log_prob = sum(log_prob.sum(dim="feature").values(True, True)) entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) + @set_composite_lp_aggregate(False) def _log_probs( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions - action = tensordict.get(self.tensor_keys.action) tensordict_clone = tensordict.select( *self.actor_network.in_keys, strict=False - ).clone() + ).copy() with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict_clone) + if isinstance(dist, CompositeDistribution): + action_keys = self.tensor_keys.action + action = tensordict.select( + *((action_keys,) if 
isinstance(action_keys, NestedKey) else action_keys) + ) + else: + action = tensordict.get(self.tensor_keys.action) + if action.requires_grad: raise RuntimeError( f"tensordict stored {self.tensor_keys.action} requires grad." ) - if isinstance(action, torch.Tensor): - log_prob = dist.log_prob(action) - else: - maybe_log_prob = dist.log_prob(tensordict) - if not isinstance(maybe_log_prob, torch.Tensor): - # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not - # be a tensor - log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) - else: - log_prob = maybe_log_prob + log_prob = dist.log_prob(action) + if not isinstance(action, torch.Tensor): + log_prob = sum( + dist.log_prob(tensordict).sum(dim="feature").values(True, True) + ) log_prob = log_prob.unsqueeze(-1) return log_prob, dist diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 0cda513e419..ab5a564abcf 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -247,7 +247,7 @@ def set_keys(self, **kwargs) -> None: if value is not None: setattr(self.tensor_keys, key, value) else: - setattr(self.tensor_keys, key, self.default_keys.key) + setattr(self.tensor_keys, key, self.default_keys().key) try: self._forward_value_estimator_keys(**kwargs) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 4c320dec46e..720e63052db 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -261,7 +261,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 actor_network: TensorDictModule @@ -1026,7 +1026,7 @@ class _AcceptedKeys: pred_val: NestedKey = "pred_val" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 out_keys = [ "loss_qvalue", diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 9af95581344..b0696675902 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -243,7 +243,7 @@ class _AcceptedKeys: log_prob: NestedKey = "_log_prob" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 actor_network: ProbabilisticActor diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 26f7d128601..fde8df0a93a 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -174,7 +174,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator: ValueEstimators = ValueEstimators.TD0 out_keys = [ "loss_actor", diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 5e24edf548e..770ae5d05ce 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -71,7 +71,7 @@ class _AcceptedKeys: action_pred: NestedKey = "action" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys actor_network: TensorDictModule actor_network_params: TensorDictParams @@ -282,7 +282,7 @@ class _AcceptedKeys: action_pred: NestedKey = "action" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys actor_network: TensorDictModule actor_network_params: TensorDictParams diff 
--git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 7f795706640..1e736e878dc 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -13,7 +13,7 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor @@ -121,14 +121,21 @@ class _AcceptedKeys: action: NestedKey = "action" state_action_value: NestedKey = "state_action_value" value: NestedKey = "state_value" - log_prob: NestedKey = "_log_prob" + log_prob: NestedKey | None = None priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + def __post_init__(self): + if self.log_prob is None: + if composite_lp_aggregate(nowarn=True): + self.log_prob = "sample_log_prob" + else: + self.log_prob = "action_log_prob" + tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys delay_actor: bool = False default_value_estimator = ValueEstimators.TD0 @@ -359,12 +366,14 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: tensordict_clone.select(*self.qvalue_network.in_keys, strict=False), self._cached_detach_qvalue_network_params, ) - state_action_value = tensordict_expand.get("state_action_value").squeeze(-1) + state_action_value = tensordict_expand.get( + self.tensor_keys.state_action_value + ).squeeze(-1) loss_actor = -( state_action_value - - self.alpha * tensordict_clone.get("sample_log_prob").squeeze(-1) + - self.alpha * tensordict_clone.get(self.tensor_keys.log_prob).squeeze(-1) ) - return loss_actor, tensordict_clone.get("sample_log_prob") + return loss_actor, tensordict_clone.get(self.tensor_keys.log_prob) def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: tensordict_save = tensordict @@ -389,30 +398,33 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: ExplorationType.RANDOM ), self.target_actor_network_params.to_module(self.actor_network): self.actor_network(next_td) - sample_log_prob = next_td.get("sample_log_prob") + sample_log_prob = next_td.get(self.tensor_keys.log_prob) # get q-values next_td = self._vmap_qvalue_networkN0( next_td, selected_q_params, ) - state_action_value = next_td.get("state_action_value") + state_action_value = next_td.get(self.tensor_keys.state_action_value) if ( state_action_value.shape[-len(sample_log_prob.shape) :] != sample_log_prob.shape ): sample_log_prob = sample_log_prob.unsqueeze(-1) next_state_value = ( - next_td.get("state_action_value") - self.alpha * sample_log_prob + next_td.get(self.tensor_keys.state_action_value) + - self.alpha * sample_log_prob ) next_state_value = next_state_value.min(0)[0] - tensordict.set(("next", "state_value"), next_state_value) + tensordict.set(("next", self.tensor_keys.value), next_state_value) target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) tensordict_expand = self._vmap_qvalue_networkN0( tensordict.select(*self.qvalue_network.in_keys, strict=False), self.qvalue_network_params, ) - pred_val = tensordict_expand.get("state_action_value").squeeze(-1) + pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( + -1 + ) td_error = abs(pred_val - target_value) loss_qval = distance_loss( pred_val, diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index d025018e9c7..47fc0508397 100644 --- 
a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -165,7 +165,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 out_keys = ["loss"] @@ -437,7 +437,7 @@ class _AcceptedKeys: steps_to_next_obs: NestedKey = "steps_to_next_obs" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 value_network: TensorDictModule diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index a92d5bfeedd..6c03461ce76 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -90,7 +90,7 @@ class _AcceptedKeys: reco_pixels: NestedKey = "reco_pixels" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys decoder: TensorDictModule reward_model: TensorDictModule @@ -244,7 +244,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TDLambda value_model: TensorDictModule @@ -402,7 +402,7 @@ class _AcceptedKeys: value: NestedKey = "state_value" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys value_model: TensorDictModule diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index ff95b0036ee..bece855ad62 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -60,7 +60,7 @@ class _AcceptedKeys: discriminator_pred: NestedKey = "d_logits" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys discriminator_network: TensorDictModule discriminator_network_params: TensorDictParams diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 300105c1ba7..ea84d59939a 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -234,7 +234,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 out_keys = [ "loss_actor", @@ -711,7 +711,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 out_keys = [ "loss_actor", diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 793a335f1b9..93b50f3e76b 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -180,7 +180,7 @@ class _AcceptedKeys: terminated: NestedKey = "terminated" tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() + default_keys = _AcceptedKeys default_value_estimator = ValueEstimators.TD0 out_keys = ["loss"] diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index cd22e03323c..bf7831d518c 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,10 +5,11 @@ from __future__ import annotations import contextlib +import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import List, Tuple import torch from tensordict import ( @@ -18,10 +19,12 @@ TensorDictParams, ) from tensordict.nn import ( + composite_lp_aggregate, CompositeDistribution, dispatch, ProbabilisticTensorDictModule, 
ProbabilisticTensorDictSequential, + set_composite_lp_aggregate, TensorDictModule, ) from tensordict.utils import NestedKey @@ -34,6 +37,8 @@ _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_add_or_extend_key, + _maybe_get_or_select, _reduce, _sum_td_features, default_value_kwargs, @@ -47,7 +52,6 @@ TDLambdaEstimator, VTrace, ) -from yaml import warnings class PPOLoss(LossModule): @@ -69,7 +73,10 @@ class PPOLoss(LossModule): Args: actor_network (ProbabilisticTensorDictSequential): policy operator. - critic_network (ValueOperator): value operator. + Typically, a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations + as input and outputting an action (or actions) as well as its log-probability value. + critic_network (ValueOperator): value operator. The critic will usually take the observations as input + and return a scalar value (``state_value`` by default) in the output keys. Keyword Args: entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the @@ -272,16 +279,18 @@ class _AcceptedKeys: Will be used for the underlying value estimator Defaults to ``"value_target"``. value (NestedKey): The input tensordict key where the state value is expected. Will be used for the underlying value estimator. Defaults to ``"state_value"``. - sample_log_prob (NestedKey): The input tensordict key where the - sample log probability is expected. Defaults to ``"sample_log_prob"``. - action (NestedKey): The input tensordict key where the action is expected. + sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the + sample log probability is expected. + Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`, + `"action_log_prob"` otherwise. + action (NestedKey or list of nested keys): The input tensordict key where the action is expected. Defaults to ``"action"``. - reward (NestedKey): The input tensordict key where the reward is expected. + reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected. Will be used for the underlying value estimator. Defaults to ``"reward"``. - done (NestedKey): The key in the input TensorDict that indicates + done (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. - terminated (NestedKey): The key in the input TensorDict that indicates + terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is terminated. Will be used for the underlying value estimator. Defaults to ``"terminated"``. 
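+
+ When the actor uses a composite (multi-head) distribution, ``action`` and ``sample_log_prob`` may also be
+ lists of nested keys. An illustrative sketch (the exact key names depend on the policy):
+
+     >>> loss.set_keys(
+     ...     action=[("agent0", "action"), ("agent1", "action")],
+     ...     sample_log_prob=[("agent0", "action_log_prob"), ("agent1", "action_log_prob")],
+     ... )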
""" @@ -289,17 +298,24 @@ class _AcceptedKeys: advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" - sample_log_prob: NestedKey = "sample_log_prob" - action: NestedKey = "action" - reward: NestedKey = "reward" - done: NestedKey = "done" - terminated: NestedKey = "terminated" - + sample_log_prob: NestedKey | List[NestedKey] | None = None + action: NestedKey | List[NestedKey] = "action" + reward: NestedKey | List[NestedKey] = "reward" + done: NestedKey | List[NestedKey] = "done" + terminated: NestedKey | List[NestedKey] = "terminated" + + def __post_init__(self): + if self.sample_log_prob is None: + if composite_lp_aggregate(nowarn=True): + self.sample_log_prob = "sample_log_prob" + else: + self.sample_log_prob = "action_log_prob" + + default_keys = _AcceptedKeys tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE - actor_network: TensorDictModule + actor_network: ProbabilisticTensorDictModule critic_network: TensorDictModule actor_network_params: TensorDictParams critic_network_params: TensorDictParams @@ -376,7 +392,7 @@ def __init__( try: device = next(self.parameters()).device - except AttributeError: + except (AttributeError, StopIteration): device = torch.device("cpu") self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) @@ -411,22 +427,28 @@ def __init__( f"clip_value must be a float or a scalar tensor, got {clip_value}." ) self.register_buffer("clip_value", clip_value) + log_prob_keys = self.actor_network.log_prob_keys + action_keys = self.actor_network.dist_sample_keys + if len(log_prob_keys) > 1: + self.set_keys(sample_log_prob=log_prob_keys, action=action_keys) + else: + self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0]) @property def functional(self): return self._functional def _set_in_keys(self): - keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") + self._in_keys = list(set(keys)) @property @@ -465,6 +487,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: reward=self.tensor_keys.reward, done=self.tensor_keys.done, terminated=self.tensor_keys.terminated, + sample_log_prob=self.tensor_keys.sample_log_prob, ) self._set_in_keys() @@ -473,33 +496,56 @@ def reset(self) -> None: def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: - if isinstance(dist, CompositeDistribution): - kwargs = {"aggregate_probabilities": False, "include_sum": False} - else: - kwargs = {} - entropy = dist.entropy(**kwargs) - if is_tensor_collection(entropy): - entropy = _sum_td_features(entropy) + entropy = dist.entropy() except NotImplementedError: - x = dist.rsample((self.samples_mc_entropy,)) - log_prob = dist.log_prob(x) - if 
is_tensor_collection(log_prob): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + if getattr(dist, "has_rsample", False): + x = dist.rsample((self.samples_mc_entropy,)) + else: + x = dist.sample((self.samples_mc_entropy,)) + with set_composite_lp_aggregate(False) if isinstance( + dist, CompositeDistribution + ) else contextlib.nullcontext(): + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) + if is_tensor_collection(entropy): + entropy = _sum_td_features(entropy) return entropy.unsqueeze(-1) def _log_weight( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: - # current log_prob of actions - action = tensordict.get(self.tensor_keys.action) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) + if isinstance(dist, CompositeDistribution): + is_composite = True + else: + is_composite = False + + # current log_prob of actions + if is_composite: + action = tensordict.select( + *( + (self.tensor_keys.action,) + if isinstance(self.tensor_keys.action, NestedKey) + else self.tensor_keys.action + ) + ) + else: + action = _maybe_get_or_select(tensordict, self.tensor_keys.action) + + prev_log_prob = _maybe_get_or_select( + tensordict, self.tensor_keys.sample_log_prob + ) - prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) if prev_log_prob.requires_grad: raise RuntimeError( f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." @@ -509,26 +555,29 @@ def _log_weight( raise RuntimeError( f"tensordict stored {self.tensor_keys.action} requires grad." ) - if isinstance(action, torch.Tensor): - log_prob = dist.log_prob(action) - else: - if isinstance(dist, CompositeDistribution): - is_composite = True - kwargs = { - "inplace": False, - "aggregate_probabilities": False, - "include_sum": False, - } - else: - is_composite = False - kwargs = {} - log_prob = dist.log_prob(tensordict, **kwargs) - if is_composite and not isinstance(prev_log_prob, TensorDict): - log_prob = _sum_td_features(log_prob) - log_prob.view_as(prev_log_prob) + log_prob = dist.log_prob(action) + if is_composite: + with set_composite_lp_aggregate(False): + if not is_tensor_collection(prev_log_prob): + # this isn't great, in general multihead actions should have a composite log-prob too + warnings.warn( + "You are using a composite distribution, yet your log-probability is a tensor. " + "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at " + "the beginning of your script to get a proper composite log-prob.", + category=UserWarning, + ) + if ( + is_composite + and not is_tensor_collection(prev_log_prob) + and is_tensor_collection(log_prob) + ): + log_prob = _sum_td_features(log_prob) + log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) + if is_tensor_collection(kl_approx): + kl_approx = _sum_td_features(kl_approx) return log_weight, dist, kl_approx @@ -910,6 +959,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion # of the weights. 
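+ # log_weight may be a tensordict when the policy uses a composite (multi-head) action distribution;
+ # it is reduced to a single log-weight tensor below before computing the effective sample size (ESS).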
         lw = log_weight.squeeze()
+        if not isinstance(lw, torch.Tensor):
+            lw = _sum_td_features(lw)
         ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
         batch = log_weight.shape[0]
@@ -920,7 +971,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        if is_tensor_collection(gain):
+            gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
         td_out.set("clip_fraction", clip_fraction)
@@ -1120,16 +1173,16 @@ def __init__(
         self.samples_mc_kl = samples_mc_kl

     def _set_in_keys(self):
-        keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
-            *self.actor_network.in_keys,
-            *[("next", key) for key in self.actor_network.in_keys],
-            *self.critic_network.in_keys,
-        ]
+        keys = []
+        _maybe_add_or_extend_key(keys, self.actor_network.in_keys)
+        _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next")
+        _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.action)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")
+
         # Get the parameter keys from the actor dist
         actor_dist_module = None
         for module in self.actor_network.modules():
@@ -1197,27 +1250,26 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
         neg_loss = log_weight.exp() * advantage
+        if is_tensor_collection(neg_loss):
+            neg_loss = _sum_td_features(neg_loss)
         with self.actor_network_params.to_module(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             current_dist = self.actor_network.get_dist(tensordict_copy)
+        is_composite = isinstance(current_dist, CompositeDistribution)
         try:
             kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
         except NotImplementedError:
             x = previous_dist.sample((self.samples_mc_kl,))
-            if isinstance(previous_dist, CompositeDistribution):
-                kwargs = {
-                    "aggregate_probabilities": False,
-                    "inplace": False,
-                    "include_sum": False,
-                }
-            else:
-                kwargs = {}
-            previous_log_prob = previous_dist.log_prob(x, **kwargs)
-            current_log_prob = current_dist.log_prob(x, **kwargs)
-            if is_tensor_collection(current_log_prob):
+            with set_composite_lp_aggregate(
+                False
+            ) if is_composite else contextlib.nullcontext():
+                previous_log_prob = previous_dist.log_prob(x)
+                current_log_prob = current_dist.log_prob(x)
+            if is_tensor_collection(previous_log_prob):
                 previous_log_prob = _sum_td_features(previous_log_prob)
+                # Both dists have presumably the same params
                 current_log_prob = _sum_td_features(current_log_prob)
             kl = (previous_log_prob - current_log_prob).mean(0)
         kl = kl.unsqueeze(-1)
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index 9ed3a7f8f3e..6e280e1f0fa 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -12,7 +12,7 @@
 import torch

 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-from tensordict.nn import dispatch, TensorDictModule
+from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule
 from tensordict.utils import NestedKey
 from torch import Tensor
@@ -207,7 +209,9 @@ class _AcceptedKeys:
             Will be used for the underlying value estimator. Defaults to ``"state_value"``.
         action (NestedKey): The input tensordict key where the action is expected.
             Defaults to ``"action"``.
         sample_log_prob (NestedKey): The input tensordict key where the
-            sample log probability is expected. Defaults to ``"sample_log_prob"``.
+            sample log probability is expected.
+            Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
+            `"action_log_prob"` otherwise.
         priority (NestedKey): The input tensordict key where the target
             priority is written to. Defaults to ``"td_error"``.
         state_action_value (NestedKey): The input tensordict key where the
@@ -224,15 +226,22 @@ class _AcceptedKeys:
         action: NestedKey = "action"
         value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
+        sample_log_prob: NestedKey | None = None
         priority: NestedKey = "td_error"
         state_action_value: NestedKey = "state_action_value"
         reward: NestedKey = "reward"
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"

+        def __post_init__(self):
+            if self.sample_log_prob is None:
+                if composite_lp_aggregate(nowarn=True):
+                    self.sample_log_prob = "sample_log_prob"
+                else:
+                    self.sample_log_prob = "action_log_prob"
+
     tensor_keys: _AcceptedKeys
-    default_keys = _AcceptedKeys()
+    default_keys = _AcceptedKeys
     delay_actor: bool = False
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 4334016503f..0cac502b347 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -11,7 +11,12 @@
 import torch

 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
+from tensordict.nn import (
+    composite_lp_aggregate,
+    dispatch,
+    ProbabilisticTensorDictSequential,
+    TensorDictModule,
+)
 from tensordict.utils import NestedKey

 from torchrl.objectives.common import LossModule
@@ -189,7 +194,8 @@ class _AcceptedKeys:
         value (NestedKey): The input tensordict key where the state value is expected.
             Will be used for the underlying value estimator. Defaults to ``"state_value"``.
         sample_log_prob (NestedKey): The input tensordict key where the sample log probability is expected.
-            Defaults to ``"sample_log_prob"``.
+            Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
+            `"action_log_prob"` otherwise.
         action (NestedKey): The input tensordict key where the action is expected.
             Defaults to ``"action"``.
         reward (NestedKey): The input tensordict key where the reward is expected.
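A minimal usage sketch (illustrative, not part of the diff), assuming tensordict's `set_composite_lp_aggregate(False).set()` idiom: with the new defaults documented above, the resolved log-prob key simply tracks the global aggregation flag.

import torch  # noqa: F401
from tensordict.nn import composite_lp_aggregate, set_composite_lp_aggregate

# Disable composite log-prob aggregation for the whole script, as the
# multi-head losses expect.
set_composite_lp_aggregate(False).set()

# Mirror of the __post_init__ logic introduced in this diff: the default key is
# "sample_log_prob" only while aggregation is enabled.
default_key = (
    "sample_log_prob" if composite_lp_aggregate(nowarn=True) else "action_log_prob"
)
print(default_key)  # expected: action_log_prob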
@@ -205,14 +211,21 @@ class _AcceptedKeys:
         advantage: NestedKey = "advantage"
         value_target: NestedKey = "value_target"
         value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
+        sample_log_prob: NestedKey | None = None
         action: NestedKey = "action"
         reward: NestedKey = "reward"
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"

+        def __post_init__(self):
+            if self.sample_log_prob is None:
+                if composite_lp_aggregate(nowarn=True):
+                    self.sample_log_prob = "sample_log_prob"
+                else:
+                    self.sample_log_prob = "action_log_prob"
+
     tensor_keys: _AcceptedKeys
-    default_keys = _AcceptedKeys()
+    default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.GAE
     out_keys = ["loss_actor", "loss_value"]
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 9d9790d52b5..64b09ea0433 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -15,7 +15,13 @@
 import torch
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
-from tensordict.nn import dispatch, TensorDictModule
+from tensordict.nn import (
+    composite_lp_aggregate,
+    CompositeDistribution,
+    dispatch,
+    set_composite_lp_aggregate,
+    TensorDictModule,
+)
 from tensordict.utils import expand_right, NestedKey
 from torch import Tensor

 from torchrl.data.tensor_specs import Composite, TensorSpec
@@ -46,17 +52,13 @@ def new_func(self, *args, **kwargs):
     return new_func


-def compute_log_prob(action_dist, action_or_tensordict, tensor_key):
+def compute_log_prob(action_dist, action_or_tensordict, tensor_key) -> torch.Tensor:
     """Compute the log probability of an action given a distribution."""
-    if isinstance(action_or_tensordict, torch.Tensor):
-        log_p = action_dist.log_prob(action_or_tensordict)
-    else:
-        maybe_log_prob = action_dist.log_prob(action_or_tensordict)
-        if not isinstance(maybe_log_prob, torch.Tensor):
-            log_p = maybe_log_prob.get(tensor_key)
-        else:
-            log_p = maybe_log_prob
-    return log_p
+    lp = action_dist.log_prob(action_or_tensordict)
+    if isinstance(action_dist, CompositeDistribution):
+        with set_composite_lp_aggregate(False):
+            return sum(lp.sum(dim="feature").values(True, True))
+    return lp


 class SACLoss(LossModule):
@@ -268,7 +270,8 @@ class _AcceptedKeys:
         state_action_value (NestedKey): The input tensordict key where the
             state action value is expected. Defaults to ``"state_action_value"``.
         log_prob (NestedKey): The input tensordict key where the log probability is expected.
-            Defaults to ``"sample_log_prob"``.
+            Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
+            `"action_log_prob"` otherwise.
         priority (NestedKey): The input tensordict key where the target
             priority is written to. Defaults to ``"td_error"``.
         reward (NestedKey): The input tensordict key where the reward is expected.
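Illustrative aside (not part of the diff): for a `CompositeDistribution`, the rewritten `compute_log_prob` above collapses the per-head log-probabilities into one tensor. A rough sketch of that reduction, using a hand-built tensordict in place of the distribution output (the key names below are made up for illustration):

import torch
from tensordict import TensorDict

# Stand-in for a non-aggregated composite log-prob: one leaf per head,
# batch size 4, each with its own feature dims.
lp = TensorDict(
    {
        ("agent0", "action_log_prob"): torch.randn(4, 1),
        ("agent1", "action_log_prob"): torch.randn(4, 2),
    },
    batch_size=[4],
)

# Same reduction as the patched helper: sum each leaf over its feature dims,
# then add the leaves together into a single [4]-shaped log-prob.
total_log_prob = sum(lp.sum(dim="feature").values(True, True))
print(total_log_prob.shape)  # torch.Size([4])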
@@ -284,14 +287,21 @@ class _AcceptedKeys:
         action: NestedKey = "action"
         value: NestedKey = "state_value"
         state_action_value: NestedKey = "state_action_value"
-        log_prob: NestedKey = "sample_log_prob"
+        log_prob: NestedKey | None = None
         priority: NestedKey = "td_error"
         reward: NestedKey = "reward"
         done: NestedKey = "done"
         terminated: NestedKey = "terminated"

+        def __post_init__(self):
+            if self.log_prob is None:
+                if composite_lp_aggregate(nowarn=True):
+                    self.log_prob = "sample_log_prob"
+                else:
+                    self.log_prob = "action_log_prob"
+
+    default_keys = _AcceptedKeys
     tensor_keys: _AcceptedKeys
-    default_keys = _AcceptedKeys()
     default_value_estimator = ValueEstimators.TD0

     actor_network: TensorDictModule
@@ -426,6 +436,13 @@ def __init__(
         self.reduction = reduction
         self.skip_done_states = skip_done_states

+        log_prob_keys = getattr(self.actor_network, "log_prob_keys", [])
+        action_keys = getattr(self.actor_network, "dist_sample_keys", [])
+        if len(log_prob_keys) > 1:
+            self.set_keys(log_prob=log_prob_keys, action=action_keys)
+        else:
+            self.set_keys(log_prob=log_prob_keys[0], action=action_keys[0])
+
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
@@ -1031,7 +1048,7 @@ class _AcceptedKeys:
         log_prob: NestedKey = "log_prob"

     tensor_keys: _AcceptedKeys
-    default_keys = _AcceptedKeys()
+    default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     delay_actor: bool = False
     out_keys = [
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 20dcf19dce3..f5d67eea164 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -205,7 +205,7 @@ class _AcceptedKeys:
         terminated: NestedKey = "terminated"

     tensor_keys: _AcceptedKeys
-    default_keys = _AcceptedKeys()
+    default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
         "loss_actor",
diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py
index 45a76e80a53..deb5b844500 100644
--- a/torchrl/objectives/td3_bc.py
+++ b/torchrl/objectives/td3_bc.py
@@ -218,7 +218,7 @@ class _AcceptedKeys:
         terminated: NestedKey = "terminated"

     tensor_keys: _AcceptedKeys
-    default_keys = _AcceptedKeys()
+    default_keys = _AcceptedKeys
     default_value_estimator = ValueEstimators.TD0
     out_keys = [
         "loss_actor",
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 9c46fc98262..3e0b97de710 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -8,10 +8,10 @@
 import re
 import warnings
 from enum import Enum
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Union

 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import TensorDictModule
 from torch import nn, Tensor
 from torch.nn import functional as F
@@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize
 def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
     # Sum all features and return a tensor
     return data.sum(dim="feature", reduce=True)
+
+
+def _maybe_get_or_select(td, key_or_keys):
+    if isinstance(key_or_keys, (str, tuple)):
+        return td.get(key_or_keys)
+    return td.select(*key_or_keys)
+
+
+def _maybe_add_or_extend_key(
+    tensor_keys: List[NestedKey],
+    key_or_list_of_keys: NestedKey | List[NestedKey],
+    prefix: NestedKey = None,
+):
+    if prefix is not None:
+        if isinstance(key_or_list_of_keys, NestedKey):
+            tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
+        else:
+            tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
+        return
+    if isinstance(key_or_list_of_keys, NestedKey):
+        tensor_keys.append(key_or_list_of_keys)
+    else:
+        tensor_keys.extend(key_or_list_of_keys)
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index dd3f9cc4589..dd253a6d908 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -13,23 +13,29 @@
 from typing import Callable, List, Union

 import torch
-from tensordict import TensorDictBase
+from tensordict import is_tensor_collection, TensorDictBase
 from tensordict.nn import (
-    CompositeDistribution,
+    composite_lp_aggregate,
     dispatch,
     ProbabilisticTensorDictModule,
+    set_composite_lp_aggregate,
     set_skip_existing,
     TensorDictModule,
     TensorDictModuleBase,
 )
 from tensordict.nn.probabilistic import interaction_type
-from tensordict.utils import NestedKey
+from tensordict.utils import NestedKey, unravel_key
 from torch import Tensor

 from torchrl._utils import RL_WARNINGS
 from torchrl.envs.utils import step_mdp

-from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST
+from torchrl.objectives.utils import (
+    _maybe_get_or_select,
+    _vmap_func,
+    hold_out_net,
+    RANDOM_MODULE_LIST,
+)
 from torchrl.objectives.value.functional import (
     generalized_advantage_estimate,
     td0_return_estimate,
@@ -83,16 +89,9 @@ def _call_actor_net(
     log_prob_key: NestedKey,
 ):
     dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False))
-    if isinstance(dist, CompositeDistribution):
-        kwargs = {
-            "aggregate_probabilities": True,
-            "inplace": False,
-            "include_sum": False,
-        }
-    else:
-        kwargs = {}
     s = actor_net._dist_sample(dist, interaction_type=interaction_type())
-    return dist.log_prob(s, **kwargs)
+    with set_composite_lp_aggregate(True):
+        return dist.log_prob(s)


 class ValueEstimatorBase(TensorDictModuleBase):
@@ -131,7 +130,9 @@ class _AcceptedKeys:
             that indicates the number of steps to the next observation.
             Defaults to ``"steps_to_next_obs"``.
         sample_log_prob (NestedKey): The key in the input tensordict that
-            indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``.
+            indicates the log probability of the sampled action.
+            Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
+            `"action_log_prob"` otherwise.
""" advantage: NestedKey = "advantage" @@ -141,10 +142,17 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" - sample_log_prob: NestedKey = "sample_log_prob" + sample_log_prob: NestedKey | None = None + + def __post_init__(self): + if self.sample_log_prob is None: + if composite_lp_aggregate(nowarn=True): + self.sample_log_prob = "sample_log_prob" + else: + self.sample_log_prob = "action_log_prob" + default_keys = _AcceptedKeys tensor_keys: _AcceptedKeys - default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] _vmap_randomness = None @@ -294,13 +302,18 @@ def out_keys(self): def set_keys(self, **kwargs) -> None: """Set tensordict key names.""" - for key, value in kwargs.items(): - if not isinstance(value, (str, tuple)): + for key, value in list(kwargs.items()): + if isinstance(value, list): + value = [unravel_key(k) for k in value] + elif not isinstance(value, (str, tuple)): + if value is None: + raise ValueError("tensordict keys cannot be None") raise ValueError( f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}" ) - if value is None: - raise ValueError("tensordict keys cannot be None") + else: + value = unravel_key(value) + if key not in self._AcceptedKeys.__dict__: raise KeyError( f"{key} is not an accepted tensordict key for advantages" @@ -313,8 +326,9 @@ def set_keys(self, **kwargs) -> None: raise KeyError( f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}" ) + kwargs[key] = value if self._tensor_keys is None: - conf = asdict(self.default_keys) + conf = asdict(self.default_keys()) conf.update(self.dep_keys) else: conf = asdict(self._tensor_keys) @@ -1766,12 +1780,11 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) - # Make sure we have the log prob computed at collection time - if self.tensor_keys.sample_log_prob not in tensordict.keys(): - raise ValueError( - f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" - ) - log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) + lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob) + if is_tensor_collection(lp): + # Sum all values to match the batch size + lp = lp.sum(dim="feature", reduce=True) + log_mu = lp.view_as(value) # Compute log prob with current policy with hold_out_net(self.actor_network):