From a3d2b680c5012dcc646c9f2bb88865db5dac43cd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 18:30:52 +0000 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- torchrl/objectives/ppo.py | 99 ++++++++++++++++++++++++++----------- torchrl/objectives/utils.py | 27 +++++++++- 2 files changed, 95 insertions(+), 31 deletions(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index eb9a916dfc1..c35bd8e818d 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -8,7 +8,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import List, Tuple import torch from tensordict import ( @@ -16,6 +16,7 @@ TensorDict, TensorDictBase, TensorDictParams, + unravel_key, ) from tensordict.nn import ( CompositeDistribution, @@ -33,6 +34,8 @@ _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _maybe_add_or_extend_key, + _maybe_get_or_select, _reduce, _sum_td_features, default_value_kwargs, @@ -67,7 +70,10 @@ class PPOLoss(LossModule): Args: actor_network (ProbabilisticTensorDictSequential): policy operator. - critic_network (ValueOperator): value operator. + Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations + as input and outputting an action (or actions) as well as its log-probability value. + critic_network (ValueOperator): value operator. The critic will usually take the observations as input + and return a scalar value (``state_value`` by default) in the output keys. Keyword Args: entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the @@ -267,16 +273,16 @@ class _AcceptedKeys: Will be used for the underlying value estimator Defaults to ``"value_target"``. value (NestedKey): The input tensordict key where the state value is expected. Will be used for the underlying value estimator. Defaults to ``"state_value"``. - sample_log_prob (NestedKey): The input tensordict key where the + sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the sample log probability is expected. Defaults to ``"sample_log_prob"``. - action (NestedKey): The input tensordict key where the action is expected. + action (NestedKey or list of nested keys): The input tensordict key where the action is expected. Defaults to ``"action"``. - reward (NestedKey): The input tensordict key where the reward is expected. + reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected. Will be used for the underlying value estimator. Defaults to ``"reward"``. - done (NestedKey): The key in the input TensorDict that indicates + done (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is done. Will be used for the underlying value estimator. Defaults to ``"done"``. - terminated (NestedKey): The key in the input TensorDict that indicates + terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates whether a trajectory is terminated. Will be used for the underlying value estimator. Defaults to ``"terminated"``. 
""" @@ -284,11 +290,11 @@ class _AcceptedKeys: advantage: NestedKey = "advantage" value_target: NestedKey = "value_target" value: NestedKey = "state_value" - sample_log_prob: NestedKey = "sample_log_prob" - action: NestedKey = "action" - reward: NestedKey = "reward" - done: NestedKey = "done" - terminated: NestedKey = "terminated" + sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob" + action: NestedKey | List[NestedKey] = "action" + reward: NestedKey | List[NestedKey] = "reward" + done: NestedKey | List[NestedKey] = "done" + terminated: NestedKey | List[NestedKey] = "terminated" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.GAE @@ -369,7 +375,7 @@ def __init__( try: device = next(self.parameters()).device - except AttributeError: + except (AttributeError, StopIteration): device = torch.device("cpu") self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) @@ -409,15 +415,36 @@ def functional(self): def _set_in_keys(self): keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.critic_network.in_keys, ] + + if isinstance(self.tensor_keys.action, NestedKey): + keys.append(self.tensor_keys.action) + else: + keys.extend(self.tensor_keys.action) + + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + keys.append(self.tensor_keys.sample_log_prob) + else: + keys.extend(self.tensor_keys.sample_log_prob) + + if isinstance(self.tensor_keys.reward, NestedKey): + keys.append(unravel_key(("next", self.tensor_keys.reward))) + else: + keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.reward]) + + if isinstance(self.tensor_keys.done, NestedKey): + keys.append(unravel_key(("next", self.tensor_keys.done))) + else: + keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.done]) + + if isinstance(self.tensor_keys.terminated, NestedKey): + keys.append(unravel_key(("next", self.tensor_keys.terminated))) + else: + keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.terminated]) + self._in_keys = list(set(keys)) @property @@ -472,25 +499,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: if is_tensor_collection(entropy): entropy = _sum_td_features(entropy) except NotImplementedError: - x = dist.rsample((self.samples_mc_entropy,)) + if getattr(dist, "has_rsample", False): + x = dist.rsample((self.samples_mc_entropy,)) + else: + x = dist.sample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) - if is_tensor_collection(log_prob): + + if is_tensor_collection(log_prob) and isinstance( + self.tensor_keys.sample_log_prob, NestedKey + ): log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_weight( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: + # current log_prob of actions - action = tensordict.get(self.tensor_keys.action) + action = _maybe_get_or_select(tensordict, self.tensor_keys.action) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) + prev_log_prob = _maybe_get_or_select( + tensordict, self.tensor_keys.sample_log_prob 
+ ) + if prev_log_prob.requires_grad: raise RuntimeError( f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." @@ -513,8 +553,8 @@ def _log_weight( else: is_composite = False kwargs = {} - log_prob = dist.log_prob(tensordict, **kwargs) - if is_composite and not isinstance(prev_log_prob, TensorDict): + log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs) + if is_composite and not is_tensor_collection(prev_log_prob): log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) @@ -1088,15 +1128,16 @@ def __init__( def _set_in_keys(self): keys = [ - self.tensor_keys.action, - self.tensor_keys.sample_log_prob, - ("next", self.tensor_keys.reward), - ("next", self.tensor_keys.done), - ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], *self.critic_network.in_keys, ] + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") + # Get the parameter keys from the actor dist actor_dist_module = None for module in self.actor_network.modules(): diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 9c46fc98262..3e0b97de710 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -8,10 +8,10 @@ import re import warnings from enum import Enum -from typing import Iterable, Optional, Union +from typing import Iterable, List, Optional, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import TensorDictModule from torch import nn, Tensor from torch.nn import functional as F @@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize def _sum_td_features(data: TensorDictBase) -> torch.Tensor: # Sum all features and return a tensor return data.sum(dim="feature", reduce=True) + + +def _maybe_get_or_select(td, key_or_keys): + if isinstance(key_or_keys, (str, tuple)): + return td.get(key_or_keys) + return td.select(*key_or_keys) + + +def _maybe_add_or_extend_key( + tensor_keys: List[NestedKey], + key_or_list_of_keys: NestedKey | List[NestedKey], + prefix: NestedKey = None, +): + if prefix is not None: + if isinstance(key_or_list_of_keys, NestedKey): + tensor_keys.append(unravel_key((prefix, key_or_list_of_keys))) + else: + tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys]) + return + if isinstance(key_or_list_of_keys, NestedKey): + tensor_keys.append(key_or_list_of_keys) + else: + tensor_keys.extend(key_or_list_of_keys) From c7c90213978c30525543a48975511093aecfbec5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 15:23:19 +0000 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- examples/agents/composite_ppo.py | 203 +++++++++++++++++++++++++ test/test_cost.py | 109 ++++++++++++- torchrl/objectives/ppo.py | 117 +++++++------- torchrl/objectives/value/advantages.py | 36 +++-- 4 files changed, 393 insertions(+), 72 deletions(-) create mode 100644 examples/agents/composite_ppo.py diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py new file mode 100644 index 00000000000..d75ce3218b3 --- /dev/null +++ b/examples/agents/composite_ppo.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multi-head agent and PPO loss +============================= + +This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions +(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses. + +The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict. +It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution +object containing the three distributions. + +The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters, +creates a distribution from these parameters, and samples from the distribution to output multiple actions. + +The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss. + +Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a +fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities` +argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False` +if not specified. + +In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in +the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used. + +""" + +import functools + +import torch +from tensordict import TensorDict +from tensordict.nn import ( + CompositeDistribution, + InteractionType, + ProbabilisticTensorDictModule as Prob, + ProbabilisticTensorDictSequential as ProbSeq, + TensorDictModule as Mod, + TensorDictSequential as Seq, + WrapModule as Wrap, +) +from torch import distributions as d +from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss + +make_params = Mod( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], +) + + +def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + +# ============================================================================= +# Example 0: aggregate_probabilities=None (default) =========================== + +dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + aggregate_probabilities=None, +) + + +policy = ProbSeq( + make_params, + Prob( + in_keys=["params"], + out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), +) + +td = 
policy(TensorDict(batch_size=[4])) +print("0. result of policy call", td) + +dist = policy.get_dist(td) +log_prob = dist.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False +) +print("0. non-aggregated log-prob") + +# We can also get the log-prob from the policy directly +log_prob = policy.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False +) +print("0. non-aggregated log-prob (from policy)") + +# Build a dummy value operator +value_operator = Seq( + Wrap( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) +) + +# Create fake data +data = policy(TensorDict(batch_size=[4])) +data.set( + "next", + TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)), +) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + + # Get the loss values + loss_vals = ppo(data) + print("0. ", loss_cls, loss_vals) + + +# =================================================================== +# Example 1: aggregate_probabilities=True =========================== + +dist_constructor.keywords["aggregate_probabilities"] = True + +td = policy(TensorDict(batch_size=[4])) +print("1. result of policy call", td) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since + # there is only one. + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")] + ) + + # Get the loss values + loss_vals = ppo(data) + print("1. ", loss_cls, loss_vals) + + +# =================================================================== +# Example 2: aggregate_probabilities=False =========================== + +dist_constructor.keywords["aggregate_probabilities"] = False + +td = policy(TensorDict(batch_size=[4])) +print("2. result of policy call", td) + +# Instantiate the loss +for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + ppo = loss_cls(policy, value_operator) + + # Keys are not the default ones - there is more than one action + ppo.set_keys( + action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + + # Get the loss values + loss_vals = ppo(data) + print("2. 
", loss_cls, loss_vals) diff --git a/test/test_cost.py b/test/test_cost.py index 1f191e41db6..7c7c97eedfc 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -34,6 +34,7 @@ TensorDictModule as Mod, TensorDictSequential, TensorDictSequential as Seq, + WrapModule, ) from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key @@ -8864,9 +8865,7 @@ def test_ppo_tensordict_keys_run( @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) @pytest.mark.parametrize( "composite_action_dist", - [ - False, - ], + [False], ) def test_ppo_notensordict( self, @@ -9060,6 +9059,110 @@ def test_ppo_value_clipping( loss = loss_fn(td) assert "loss_critic" in loss.keys() + def test_ppo_composite_dists(self): + d = torch.distributions + + make_params = TensorDictModule( + lambda: ( + torch.ones(4), + torch.ones(4), + torch.ones(4, 2), + torch.ones(4, 2), + torch.ones(4, 10) / 10, + torch.zeros(4, 10), + torch.ones(4, 10), + ), + in_keys=[], + out_keys=[ + ("params", "gamma", "concentration"), + ("params", "gamma", "rate"), + ("params", "Kumaraswamy", "concentration0"), + ("params", "Kumaraswamy", "concentration1"), + ("params", "mixture", "logits"), + ("params", "mixture", "loc"), + ("params", "mixture", "scale"), + ], + ) + + def mixture_constructor(logits, loc, scale): + return d.MixtureSameFamily( + d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale) + ) + + dist_constructor = functools.partial( + CompositeDistribution, + distribution_map={ + "gamma": d.Gamma, + "Kumaraswamy": d.Kumaraswamy, + "mixture": mixture_constructor, + }, + name_map={ + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + }, + aggregate_probabilities=False, + include_sum=False, + inplace=True, + ) + policy = ProbSeq( + make_params, + ProbabilisticTensorDictModule( + in_keys=["params"], + out_keys=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + distribution_class=dist_constructor, + return_log_prob=True, + default_interaction_type=InteractionType.RANDOM, + ), + ) + # We want to make sure there is no warning + td = policy(TensorDict(batch_size=[4])) + assert isinstance( + policy.get_dist(td).log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False + ), + TensorDict, + ) + assert isinstance( + policy.log_prob( + td, aggregate_probabilities=False, inplace=False, include_sum=False + ), + TensorDict, + ) + value_operator = Seq( + WrapModule( + lambda td: td.set("state_value", torch.ones((*td.shape, 1))), + out_keys=["state_value"], + ) + ) + for cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + data = policy(TensorDict(batch_size=[4])) + data.set( + "next", + TensorDict( + reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool) + ), + ) + ppo = cls(policy, value_operator) + ppo.set_keys( + action=[ + ("agent0", "action"), + ("agent1", "action"), + ("agent2", "action"), + ], + sample_log_prob=[ + ("agent0", "action_log_prob"), + ("agent1", "action_log_prob"), + ("agent2", "action_log_prob"), + ], + ) + loss = ppo(data) + loss.sum(reduce=True) + class TestA2C(LossModuleTestBase): seed = 0 diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index c35bd8e818d..9e833d0518b 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -16,7 +16,6 @@ TensorDict, TensorDictBase, TensorDictParams, - unravel_key, ) from tensordict.nn import ( CompositeDistribution, @@ -414,36 +413,15 @@ def functional(self): return self._functional def 
_set_in_keys(self): - keys = [ - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] - - if isinstance(self.tensor_keys.action, NestedKey): - keys.append(self.tensor_keys.action) - else: - keys.extend(self.tensor_keys.action) - - if isinstance(self.tensor_keys.sample_log_prob, NestedKey): - keys.append(self.tensor_keys.sample_log_prob) - else: - keys.extend(self.tensor_keys.sample_log_prob) - - if isinstance(self.tensor_keys.reward, NestedKey): - keys.append(unravel_key(("next", self.tensor_keys.reward))) - else: - keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.reward]) - - if isinstance(self.tensor_keys.done, NestedKey): - keys.append(unravel_key(("next", self.tensor_keys.done))) - else: - keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.done]) - - if isinstance(self.tensor_keys.terminated, NestedKey): - keys.append(unravel_key(("next", self.tensor_keys.terminated))) - else: - keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.terminated]) + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) + _maybe_add_or_extend_key(keys, self.tensor_keys.action) + _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) + _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next") + _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next") self._in_keys = list(set(keys)) @@ -483,6 +461,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: reward=self.tensor_keys.reward, done=self.tensor_keys.done, terminated=self.tensor_keys.terminated, + sample_log_prob=self.tensor_keys.sample_log_prob, ) self._set_in_keys() @@ -490,29 +469,34 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: + if isinstance(dist, CompositeDistribution): + aggregate = dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False + kwargs = {"aggregate_probabilities": aggregate, "include_sum": include_sum} + else: + kwargs = {} try: - if isinstance(dist, CompositeDistribution): - kwargs = {"aggregate_probabilities": False, "include_sum": False} - else: - kwargs = {} entropy = dist.entropy(**kwargs) - if is_tensor_collection(entropy): - entropy = _sum_td_features(entropy) except NotImplementedError: if getattr(dist, "has_rsample", False): x = dist.rsample((self.samples_mc_entropy,)) else: x = dist.sample((self.samples_mc_entropy,)) - log_prob = dist.log_prob(x) + log_prob = dist.log_prob(x, **kwargs) - if is_tensor_collection(log_prob) and isinstance( - self.tensor_keys.sample_log_prob, NestedKey - ): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) - else: - log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + if is_tensor_collection(log_prob): + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) entropy = -log_prob.mean(0) + if is_tensor_collection(entropy): + entropy = _sum_td_features(entropy) return entropy.unsqueeze(-1) def _log_weight( @@ -545,21 +529,33 @@ def _log_weight( else: if isinstance(dist, CompositeDistribution): is_composite = True + aggregate = 
dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False kwargs = { "inplace": False, - "aggregate_probabilities": False, - "include_sum": False, + "aggregate_probabilities": aggregate, + "include_sum": include_sum, } else: is_composite = False kwargs = {} log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs) - if is_composite and not is_tensor_collection(prev_log_prob): + if ( + is_composite + and not is_tensor_collection(prev_log_prob) + and is_tensor_collection(log_prob) + ): log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) + if is_tensor_collection(kl_approx): + kl_approx = _sum_td_features(kl_approx) return log_weight, dist, kl_approx @@ -933,6 +929,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: gain2 = ratio * advantage gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0] + if is_tensor_collection(gain): + gain = gain.sum(reduce=True) td_out = TensorDict({"loss_objective": -gain}, batch_size=[]) td_out.set("clip_fraction", clip_fraction) @@ -940,7 +938,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: entropy = self.get_entropy_bonus(dist) td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging - td_out.set("loss_entropy", -self.entropy_coef * entropy) + td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) @@ -1127,11 +1125,10 @@ def __init__( self.samples_mc_kl = samples_mc_kl def _set_in_keys(self): - keys = [ - *self.actor_network.in_keys, - *[("next", key) for key in self.actor_network.in_keys], - *self.critic_network.in_keys, - ] + keys = [] + _maybe_add_or_extend_key(keys, self.actor_network.in_keys) + _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next") + _maybe_add_or_extend_key(keys, self.critic_network.in_keys) _maybe_add_or_extend_key(keys, self.tensor_keys.action) _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob) _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next") @@ -1197,6 +1194,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: advantage = (advantage - loc) / scale log_weight, dist, kl_approx = self._log_weight(tensordict_copy) neg_loss = log_weight.exp() * advantage + if is_tensor_collection(neg_loss): + neg_loss = _sum_td_features(neg_loss) with self.actor_network_params.to_module( self.actor_network @@ -1207,16 +1206,22 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) if isinstance(previous_dist, CompositeDistribution): + aggregate = previous_dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = previous_dist.include_sum + if include_sum is None: + include_sum = False kwargs = { - "aggregate_probabilities": False, + "aggregate_probabilities": aggregate, "inplace": False, - "include_sum": False, + "include_sum": include_sum, } else: kwargs = {} previous_log_prob = previous_dist.log_prob(x, **kwargs) current_log_prob = current_dist.log_prob(x, **kwargs) - if is_tensor_collection(current_log_prob): + if is_tensor_collection(previous_log_prob): previous_log_prob = _sum_td_features(previous_log_prob) 
current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 3b08780e24c..fa05c8860a6 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -13,7 +13,7 @@ from typing import Callable, List, Union import torch -from tensordict import TensorDictBase +from tensordict import is_tensor_collection, TensorDictBase from tensordict.nn import ( CompositeDistribution, dispatch, @@ -23,13 +23,18 @@ TensorDictModuleBase, ) from tensordict.nn.probabilistic import interaction_type -from tensordict.utils import NestedKey +from tensordict.utils import NestedKey, unravel_key from torch import Tensor from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp -from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST +from torchrl.objectives.utils import ( + _maybe_get_or_select, + _vmap_func, + hold_out_net, + RANDOM_MODULE_LIST, +) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, td0_return_estimate, @@ -293,13 +298,18 @@ def out_keys(self): def set_keys(self, **kwargs) -> None: """Set tensordict key names.""" - for key, value in kwargs.items(): - if not isinstance(value, (str, tuple)): + for key, value in list(kwargs.items()): + if isinstance(value, list): + value = [unravel_key(k) for k in value] + elif not isinstance(value, (str, tuple)): + if value is None: + raise ValueError("tensordict keys cannot be None") raise ValueError( f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}" ) - if value is None: - raise ValueError("tensordict keys cannot be None") + else: + value = unravel_key(value) + if key not in self._AcceptedKeys.__dict__: raise KeyError( f"{key} is not an accepted tensordict key for advantages" @@ -312,6 +322,7 @@ def set_keys(self, **kwargs) -> None: raise KeyError( f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}" ) + kwargs[key] = value if self._tensor_keys is None: conf = asdict(self.default_keys) conf.update(self.dep_keys) @@ -1765,12 +1776,11 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) - # Make sure we have the log prob computed at collection time - if self.tensor_keys.sample_log_prob not in tensordict.keys(): - raise ValueError( - f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" - ) - log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) + lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob) + if is_tensor_collection(lp): + # Sum all values to match the batch size + lp = lp.sum(dim="feature", reduce=True) + log_mu = lp.view_as(value) # Compute log prob with current policy with hold_out_net(self.actor_network): From 86ab9b7d7bbe8f24759642798acc6b45c25cd335 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 10 Jan 2025 10:02:49 +0000 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- torchrl/objectives/ppo.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 079a1efa92c..53a7bfae5df 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -27,6 +27,7 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl._utils import _replace_last from torchrl.objectives.common import 
LossModule from torchrl.objectives.utils import ( @@ -1267,3 +1268,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def reset(self) -> None: self.beta = self._beta_init + + +def _make_lp_get_error(tensor_keys, log_prob, err): + result = ( + f"The sample log probability key (tensor_keys.sample_log_prob={tensor_keys.sample_log_prob}) does " + f"not appear in the log-prob tensordict with keys {list(log_prob.keys(True, True))}. " + ) + # now check if we can substitute the actions with action_log_prob and retrieve the log-probs + action_keys = tensor_keys.action + if isinstance(action_keys, list): + has_all_log_probs = True + log_prob_keys = [] + for action_key in action_keys: + log_prob_key = _replace_last(action_key, "action_log_prob") + log_prob_keys.append(log_prob_key) + if log_prob_key not in log_prob: + has_all_log_probs = False + break + if has_all_log_probs: + result += ( + f"The action keys are {action_keys} and all log_prob keys {log_prob_keys} are present in the " + f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={log_prob_keys})` should resolve " + f"this error." + ) + return KeyError(result) + result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=)." + return KeyError(result) From 14e639d0393df511e07efbdc1ab8ffa474eb2e04 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 10 Jan 2025 15:38:08 +0000 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- examples/agents/composite_actor.py | 6 +++ test/test_cost.py | 60 ++++++++++++++++++++++++++---- torchrl/objectives/ppo.py | 45 +++++++++++++--------- 3 files changed, 86 insertions(+), 25 deletions(-) diff --git a/examples/agents/composite_actor.py b/examples/agents/composite_actor.py index ae08062e084..c7e83095983 100644 --- a/examples/agents/composite_actor.py +++ b/examples/agents/composite_actor.py @@ -50,3 +50,9 @@ def forward(self, x): data = TensorDict({"x": torch.rand(10)}, []) module(data) print(actor(data)) + + +# TODO: +# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action") +# 2. Must multi-head require an action_key to be a list of keys (I guess so) +# 3. 
Using maps in the Actor diff --git a/test/test_cost.py b/test/test_cost.py index 7c7c97eedfc..7ee72543ecf 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -7908,11 +7908,11 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", - action_key="action", + action_key=None, observation_key="observation", sample_log_prob_key="sample_log_prob", composite_action_dist=False, - aggregate_probabilities=True, + aggregate_probabilities=None, ): # Actor action_spec = Bounded( @@ -7922,13 +7922,17 @@ def _create_mock_actor( action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) if composite_action_dist: + if action_key is None: + action_key = ("action", "action1") + else: + action_key = (action_key, "action1") distribution_class = functools.partial( CompositeDistribution, distribution_map={ "action1": TanhNormal, }, name_map={ - "action1": (action_key, "action1"), + "action1": action_key, }, log_prob_key=sample_log_prob_key, aggregate_probabilities=aggregate_probabilities, @@ -7939,6 +7943,8 @@ def _create_mock_actor( ] actor_in_keys = ["params"] else: + if action_key is None: + action_key = "action" distribution_class = TanhNormal module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( @@ -8149,8 +8155,8 @@ def _create_seq_mock_data_ppo( action_dim=4, atoms=None, device="cpu", - sample_log_prob_key="sample_log_prob", - action_key="action", + sample_log_prob_key=None, + action_key=None, composite_action_dist=False, ): # create a tensordict @@ -8172,6 +8178,17 @@ def _create_seq_mock_data_ppo( params_scale = torch.rand_like(action) / 10 loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) + if sample_log_prob_key is None: + if composite_action_dist: + sample_log_prob_key = ("action", "action1_log_prob") + else: + sample_log_prob_key = "sample_log_prob" + + if action_key is None: + if composite_action_dist: + action_key = ("action", "action1") + else: + action_key = "action" td = TensorDict( batch_size=(batch, T), source={ @@ -8183,7 +8200,7 @@ def _create_seq_mock_data_ppo( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: {"action1": action} if composite_action_dist else action, + action_key: action, sample_log_prob_key: ( torch.randn_like(action[..., 1]) / 10 ).masked_fill_(~mask, 0.0), @@ -8263,6 +8280,13 @@ def test_ppo( loss_critic_type="l2", functional=functional, ) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + if advantage is not None: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) if advantage is not None: advantage(td) else: @@ -8356,7 +8380,9 @@ def test_ppo_composite_no_aggregate( loss_critic_type="l2", functional=functional, ) + loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) if advantage is not None: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) else: if td_est is not None: @@ -8464,7 +8490,12 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): ) if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + + if composite_action_dist: + loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) loss = 
loss_fn(td) loss_critic = loss["loss_critic"] @@ -8571,7 +8602,14 @@ def test_ppo_shared_seq( ) if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + + if composite_action_dist: + loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) + loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) + loss = loss_fn(td).exclude("entropy") sum(val for key, val in loss.items() if key.startswith("loss_")).backward() @@ -8659,7 +8697,11 @@ def zero_param(p): # assert len(list(floss_fn.parameters())) == 0 with params.to_module(loss_fn): if advantage is not None: + if composite_action_dist: + advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) + if composite_action_dist: + loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8760,8 +8802,8 @@ def test_ppo_tensordict_keys_run( "advantage": "advantage_test", "value_target": "value_target_test", "value": "state_value_test", - "sample_log_prob": "sample_log_prob_test", - "action": "action_test", + "sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test", + "action": ("action_test", "action") if composite_action_dist else "action_test", } td = self._create_seq_mock_data_ppo( @@ -8809,6 +8851,8 @@ def test_ppo_tensordict_keys_run( raise NotImplementedError loss_fn = loss_class(actor, value, loss_critic_type="l2") + if composite_action_dist: + tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]] loss_fn.set_keys(**tensor_keys) if advantage is not None: # collect tensordict key names for the advantage module diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 53a7bfae5df..18d95bcdd7a 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,6 +5,7 @@ from __future__ import annotations import contextlib +import warnings from copy import deepcopy from dataclasses import dataclass @@ -531,26 +532,35 @@ def _log_weight( raise RuntimeError( f"tensordict stored {self.tensor_keys.action} requires grad." ) - if isinstance(action, torch.Tensor): + if isinstance(dist, CompositeDistribution): + is_composite = True + aggregate = dist.aggregate_probabilities + if aggregate is None: + aggregate = False + include_sum = dist.include_sum + if include_sum is None: + include_sum = False + kwargs = { + "inplace": False, + "aggregate_probabilities": aggregate, + "include_sum": include_sum, + } + else: + is_composite = False + kwargs = {} + if not is_composite: log_prob = dist.log_prob(action) else: - if isinstance(dist, CompositeDistribution): - is_composite = True - aggregate = dist.aggregate_probabilities - if aggregate is None: - aggregate = False - include_sum = dist.include_sum - if include_sum is None: - include_sum = False - kwargs = { - "inplace": False, - "aggregate_probabilities": aggregate, - "include_sum": include_sum, - } - else: - is_composite = False - kwargs = {} log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs) + if not is_tensor_collection(prev_log_prob): + # this isn't great, in general multihead actions should have a composite log-prob too + warnings.warn( + "You are using a composite distribution, yet your log-probability is a tensor. 
" + "This usually happens whenever the CompositeDistribution has aggregate_probabilities=True " + "or include_sum=True. These options should be avoided: leaf log-probs should be written " + "independently and PPO will take care of the aggregation.", + category=UserWarning, + ) if ( is_composite and not is_tensor_collection(prev_log_prob) @@ -559,6 +569,7 @@ def _log_weight( log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) + print(log_prob , prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) if is_tensor_collection(kl_approx): From 30166a0883768cf168f6ccf9bc50566204dfcb78 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 13 Jan 2025 14:42:48 +0000 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- test/test_cost.py | 53 ++++++++++++++++++++++++++------------- torchrl/objectives/ppo.py | 3 +-- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 7ee72543ecf..a538a8d3418 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -7918,14 +7918,13 @@ def _create_mock_actor( action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - if composite_action_dist: - action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) if composite_action_dist: if action_key is None: action_key = ("action", "action1") else: action_key = (action_key, "action1") + action_spec = Composite({action_key: {"action1": action_spec}}) distribution_class = functools.partial( CompositeDistribution, distribution_map={ @@ -8380,7 +8379,10 @@ def test_ppo_composite_no_aggregate( loss_critic_type="l2", functional=functional, ) - loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) if advantage is not None: advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) @@ -8495,7 +8497,10 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): advantage(td) if composite_action_dist: - loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8607,8 +8612,14 @@ def test_ppo_shared_seq( advantage(td) if composite_action_dist: - loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) - loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) + loss_fn2.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td).exclude("entropy") @@ -8701,7 +8712,10 @@ def zero_param(p): advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")]) advantage(td) if composite_action_dist: - loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]) + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -8791,10 +8805,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, 
composite_action_dist): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - @pytest.mark.parametrize("composite_action_dist", [True, False]) - def test_ppo_tensordict_keys_run( - self, loss_class, advantage, td_est, composite_action_dist - ): + def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8802,18 +8813,16 @@ def test_ppo_tensordict_keys_run( "advantage": "advantage_test", "value_target": "value_target_test", "value": "state_value_test", - "sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test", - "action": ("action_test", "action") if composite_action_dist else "action_test", + "sample_log_prob": "sample_log_prob_test", + "action": "action_test", } td = self._create_seq_mock_data_ppo( sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], - composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( sample_log_prob_key=tensor_keys["sample_log_prob"], - composite_action_dist=composite_action_dist, action_key=tensor_keys["action"], ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) @@ -8851,8 +8860,6 @@ def test_ppo_tensordict_keys_run( raise NotImplementedError loss_fn = loss_class(actor, value, loss_critic_type="l2") - if composite_action_dist: - tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]] loss_fn.set_keys(**tensor_keys) if advantage is not None: # collect tensordict key names for the advantage module @@ -9030,11 +9037,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist): reduction=reduction, ) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) loss = loss_fn(td) if reduction == "none": for key in loss.keys(): if key.startswith("loss_"): - assert loss[key].shape == td.shape + assert loss[key].shape == td.shape, key else: for key in loss.keys(): if not key.startswith("loss_"): @@ -9082,6 +9094,11 @@ def test_ppo_value_clipping( clip_value=clip_value, ) advantage(td) + if composite_action_dist: + loss_fn.set_keys( + action=("action", "action1"), + sample_log_prob=[("action", "action1_log_prob")], + ) value = td.pop(loss_fn.tensor_keys.value) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 18d95bcdd7a..2412ea62180 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -569,7 +569,6 @@ def _log_weight( log_prob = _sum_td_features(log_prob) log_prob.view_as(prev_log_prob) - print(log_prob , prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) if is_tensor_collection(kl_approx): @@ -946,7 +945,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ratio = log_weight_clip.exp() gain2 = ratio * advantage - gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0] + gain = torch.stack([gain1, gain2], -1).min(dim=-1).values if is_tensor_collection(gain): gain = _sum_td_features(gain) td_out = TensorDict({"loss_objective": -gain}, batch_size=[]) From b42dbc60e526daf606c39edcd6047d53490864e0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 Jan 2025 15:07:07 +0000 Subject: [PATCH 6/7] Update 
[ghstack-poisoned] --- .../unittest/linux_sota/scripts/test_sota.py | 26 ++-- test/test_cost.py | 129 +++++++++++------- torchrl/objectives/deprecated.py | 33 +++-- torchrl/objectives/ppo.py | 55 ++++---- torchrl/objectives/redq.py | 1 + 5 files changed, 145 insertions(+), 99 deletions(-) diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py index b7af381634c..25d1e7a4390 100644 --- a/.github/unittest/linux_sota/scripts/test_sota.py +++ b/.github/unittest/linux_sota/scripts/test_sota.py @@ -188,19 +188,6 @@ ppo.collector.frames_per_batch=16 \ logger.mode=offline \ logger.backend= -""", - "dreamer": """python sota-implementations/dreamer/dreamer.py \ - collector.total_frames=600 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=200 \ - env.n_parallel_envs=1 \ - optimization.optim_steps_per_batch=1 \ - logger.video=False \ - logger.backend=csv \ - replay_buffer.buffer_size=120 \ - replay_buffer.batch_size=24 \ - replay_buffer.batch_length=12 \ - networks.rssm_hidden_dim=17 """, "ddpg-single": """python sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ @@ -289,6 +276,19 @@ logger.backend= """, "bandits": """python sota-implementations/bandits/dqn.py --n_steps=100 +""", + "dreamer": """python sota-implementations/dreamer/dreamer.py \ + collector.total_frames=600 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + env.n_parallel_envs=1 \ + optimization.optim_steps_per_batch=1 \ + logger.video=False \ + logger.backend=csv \ + replay_buffer.buffer_size=120 \ + replay_buffer.batch_size=24 \ + replay_buffer.batch_length=12 \ + networks.rssm_hidden_dim=17 """, } diff --git a/test/test_cost.py b/test/test_cost.py index 61ff0517024..8bbb7edce05 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -18,7 +18,6 @@ import torch from packaging import version, version as pack_version - from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict._C import unravel_keys from tensordict.nn import ( @@ -37,6 +36,7 @@ TensorDictSequential as Seq, WrapModule, ) +from tensordict.nn.distributions.composite import _add_suffix from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn @@ -199,6 +199,13 @@ def get_devices(): class LossModuleTestBase: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) assert hasattr( @@ -3541,13 +3548,6 @@ def test_td3bc_reduction(self, reduction): class TestSAC(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -4623,13 +4623,6 @@ def test_sac_reduction(self, reduction, version, composite_action_dist): class TestDiscreteSAC(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -6786,7 +6779,7 @@ def test_redq_tensordict_keys(self, td_est): "priority": "td_error", "action": "action", "value": "state_value", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", "state_action_value": "state_action_value", "reward": 
"reward", "done": "done", @@ -6849,12 +6842,22 @@ def test_redq_notensordict( actor_network=actor, qvalue_network=qvalue, ) - loss.set_keys( - action=action_key, - reward=reward_key, - done=done_key, - terminated=terminated_key, - ) + if deprec: + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + log_prob=_add_suffix(action_key, "_log_prob"), + ) + else: + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + sample_log_prob=_add_suffix(action_key, "_log_prob"), + ) kwargs = { action_key: td.get(action_key), @@ -7916,13 +7919,6 @@ def test_dcql_reduction(self, reduction): class TestPPO(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -8003,7 +7999,7 @@ def _create_mock_actor_value( action_dim=4, device="cpu", composite_action_dist=False, - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", ): # Actor action_spec = Bounded( @@ -8058,7 +8054,7 @@ def _create_mock_actor_value_shared( action_dim=4, device="cpu", composite_action_dist=False, - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", ): # Actor action_spec = Bounded( @@ -8123,7 +8119,7 @@ def _create_mock_data_ppo( reward_key="reward", done_key="done", terminated_key="terminated", - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", composite_action_dist=False, ): # create a tensordict @@ -8834,7 +8830,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): "advantage": "advantage_test", "value_target": "value_target_test", "value": "state_value_test", - "sample_log_prob": "sample_log_prob_test", + "sample_log_prob": "action_log_prob_test", "action": "action_test", } @@ -9242,13 +9238,6 @@ def mixture_constructor(logits, loc, scale): class TestA2C(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -9814,7 +9803,7 @@ def test_a2c_tensordict_keys_run( value_key = "state_value_test" action_key = "action_test" reward_key = "reward_test" - sample_log_prob_key = "sample_log_prob_test" + sample_log_prob_key = "action_log_prob_test" done_key = ("done", "test") terminated_key = ("terminated", "test") @@ -10258,7 +10247,7 @@ def test_reinforce_tensordict_keys(self, td_est): "advantage": "advantage", "value_target": "value_target", "value": "state_value", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", @@ -10316,7 +10305,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -11788,7 +11777,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ 
-12604,7 +12593,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -15228,6 +15217,7 @@ def test_successive_traj_gae( ["half", torch.half, "cpu"], ], ) +@set_composite_lp_aggregate(False) def test_shared_params(dest, expected_dtype, expected_device): if torch.cuda.device_count() == 0 and dest == "cuda": pytest.skip("no cuda device available") @@ -15332,6 +15322,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: class TestAdv: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + @pytest.mark.parametrize( "adv,kwargs", [ @@ -15369,7 +15366,7 @@ def test_dispatch( ) kwargs = { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next_reward": torch.randn(1, 10, 1, requires_grad=True), "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), @@ -15431,7 +15428,7 @@ def test_diff_reward( td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15504,7 +15501,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15575,7 +15572,7 @@ def test_time_dim(self, adv, kwargs, shifted=True): td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15676,7 +15673,7 @@ def test_skip_existing( td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "state_value": torch.ones(1, 10, 1), "next": { "obs": torch.randn(1, 10, 3), @@ -15814,6 +15811,13 @@ def test_set_deprecated_keys(self, adv, kwargs): class TestBase: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def test_decorators(self): class MyLoss(LossModule): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -16033,6 +16037,13 @@ class _AcceptedKeys: class TestUtils: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + @pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip @pytest.mark.parametrize("T", [1, 10]) @pytest.mark.parametrize("device", get_default_devices()) @@ -16203,6 +16214,7 @@ def fun(a, b, time_dim=-2): (SoftUpdate, {"eps": 0.99}), ], ) +@set_composite_lp_aggregate(False) def test_updater_warning(updater, kwarg): with warnings.catch_warnings(): dqn = DQNLoss(torch.nn.Linear(3, 4), 
delay_value=True, action_space="one_hot") @@ -16215,6 +16227,13 @@ def test_updater_warning(updater, kwarg): class TestSingleCall: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def _mock_value_net(self, has_target, value_key): model = nn.Linear(3, 1) module = TensorDictModule(model, in_keys=["obs"], out_keys=[value_key]) @@ -16267,6 +16286,7 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True) assert (value != value_).all() +@set_composite_lp_aggregate(False) def test_instantiate_with_different_keys(): loss_1 = DQNLoss( value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True @@ -16281,6 +16301,13 @@ def test_instantiate_with_different_keys(): class TestBuffer: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + # @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half)) # def test_param_cast(self, dtype): # param = nn.Parameter(torch.zeros(3)) @@ -16390,6 +16417,7 @@ def __init__(self): TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) @pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile") +@set_composite_lp_aggregate(False) def test_exploration_compile(): try: torch._dynamo.reset_code_caches() @@ -16456,6 +16484,7 @@ def func(t): assert it == exploration_type() +@set_composite_lp_aggregate(False) def test_loss_exploration(): class DummyLoss(LossModule): def forward(self, td, mode): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 879552254ee..1e736e878dc 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -13,7 +13,7 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor @@ -121,12 +121,20 @@ class _AcceptedKeys: action: NestedKey = "action" state_action_value: NestedKey = "state_action_value" value: NestedKey = "state_value" - log_prob: NestedKey = "_log_prob" + log_prob: NestedKey | None = None priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + def __post_init__(self): + if self.log_prob is None: + if composite_lp_aggregate(nowarn=True): + self.log_prob = "sample_log_prob" + else: + self.log_prob = "action_log_prob" + + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys delay_actor: bool = False default_value_estimator = ValueEstimators.TD0 @@ -358,12 +366,14 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: tensordict_clone.select(*self.qvalue_network.in_keys, strict=False), self._cached_detach_qvalue_network_params, ) - state_action_value = tensordict_expand.get("state_action_value").squeeze(-1) + state_action_value = tensordict_expand.get( + self.tensor_keys.state_action_value + ).squeeze(-1) loss_actor = -( state_action_value - - self.alpha * tensordict_clone.get("sample_log_prob").squeeze(-1) + - self.alpha * tensordict_clone.get(self.tensor_keys.log_prob).squeeze(-1) ) - return loss_actor, tensordict_clone.get("sample_log_prob") + return loss_actor, tensordict_clone.get(self.tensor_keys.log_prob) def _qvalue_loss(self, tensordict: 
TensorDictBase) -> Tensor: tensordict_save = tensordict @@ -388,30 +398,33 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: ExplorationType.RANDOM ), self.target_actor_network_params.to_module(self.actor_network): self.actor_network(next_td) - sample_log_prob = next_td.get("sample_log_prob") + sample_log_prob = next_td.get(self.tensor_keys.log_prob) # get q-values next_td = self._vmap_qvalue_networkN0( next_td, selected_q_params, ) - state_action_value = next_td.get("state_action_value") + state_action_value = next_td.get(self.tensor_keys.state_action_value) if ( state_action_value.shape[-len(sample_log_prob.shape) :] != sample_log_prob.shape ): sample_log_prob = sample_log_prob.unsqueeze(-1) next_state_value = ( - next_td.get("state_action_value") - self.alpha * sample_log_prob + next_td.get(self.tensor_keys.state_action_value) + - self.alpha * sample_log_prob ) next_state_value = next_state_value.min(0)[0] - tensordict.set(("next", "state_value"), next_state_value) + tensordict.set(("next", self.tensor_keys.value), next_state_value) target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) tensordict_expand = self._vmap_qvalue_networkN0( tensordict.select(*self.qvalue_network.in_keys, strict=False), self.qvalue_network_params, ) - pred_val = tensordict_expand.get("state_action_value").squeeze(-1) + pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( + -1 + ) td_error = abs(pred_val - target_value) loss_qval = distance_loss( pred_val, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 221b6924fbf..46783866588 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -317,7 +317,6 @@ def __post_init__(self): target_actor_network_params: TensorDictParams target_critic_network_params: TensorDictParams - @set_composite_lp_aggregate(False) def __init__( self, actor_network: ProbabilisticTensorDictSequential | None = None, @@ -487,7 +486,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def reset(self) -> None: pass - @set_composite_lp_aggregate(False) def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: entropy = dist.entropy() @@ -496,20 +494,21 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: x = dist.rsample((self.samples_mc_entropy,)) else: x = dist.sample((self.samples_mc_entropy,)) - log_prob = dist.log_prob(x) - - if is_tensor_collection(log_prob): - if isinstance(self.tensor_keys.sample_log_prob, NestedKey): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) - else: - log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + with set_composite_lp_aggregate(False) if isinstance( + dist, CompositeDistribution + ) else contextlib.nullcontext(): + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) entropy = -log_prob.mean(0) if is_tensor_collection(entropy): entropy = _sum_td_features(entropy) return entropy.unsqueeze(-1) - @set_composite_lp_aggregate(False) def _log_weight( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: @@ -550,21 +549,22 @@ def _log_weight( ) log_prob = dist.log_prob(action) if is_composite: - if not is_tensor_collection(prev_log_prob): - # this isn't great, in general multihead actions should have a composite log-prob too - warnings.warn( - "You are using a composite 
distribution, yet your log-probability is a tensor. " - "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at " - "the beginning of your script to get a proper composite log-prob.", - category=UserWarning, - ) - if ( - is_composite - and not is_tensor_collection(prev_log_prob) - and is_tensor_collection(log_prob) - ): - log_prob = _sum_td_features(log_prob) - log_prob.view_as(prev_log_prob) + with set_composite_lp_aggregate(False): + if not is_tensor_collection(prev_log_prob): + # this isn't great, in general multihead actions should have a composite log-prob too + warnings.warn( + "You are using a composite distribution, yet your log-probability is a tensor. " + "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at " + "the beginning of your script to get a proper composite log-prob.", + category=UserWarning, + ) + if ( + is_composite + and not is_tensor_collection(prev_log_prob) + and is_tensor_collection(log_prob) + ): + log_prob = _sum_td_features(log_prob) + log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -1215,11 +1215,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: self.actor_network ) if self.functional else contextlib.nullcontext(): current_dist = self.actor_network.get_dist(tensordict_copy) + is_composite = isinstance(current_dist, CompositeDistribution) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - with set_composite_lp_aggregate(False): + with set_composite_lp_aggregate( + False + ) if is_composite else contextlib.nullcontext(): previous_log_prob = previous_dist.log_prob(x) current_log_prob = current_dist.log_prob(x) if is_tensor_collection(previous_log_prob): diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index bbd8621212b..6e280e1f0fa 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -240,6 +240,7 @@ def __post_init__(self): else: self.sample_log_prob = "action_log_prob" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys delay_actor: bool = False default_value_estimator = ValueEstimators.TD0 From b05d735ba12295b39717be445a33e8504a4d02d6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 16 Jan 2025 11:10:50 +0000 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- examples/agents/composite_ppo.py | 187 ++++++++++++++----------------- torchrl/objectives/ppo.py | 2 + 2 files changed, 89 insertions(+), 100 deletions(-) diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py index d75ce3218b3..501dceb651d 100644 --- a/examples/agents/composite_ppo.py +++ b/examples/agents/composite_ppo.py @@ -4,28 +4,71 @@ # LICENSE file in the root directory of this source tree. """ -Multi-head agent and PPO loss +Multi-head Agent and PPO Loss ============================= - This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions (Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses. -The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict. -It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution -object containing the three distributions. 
-
-The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
-creates a distribution from these parameters, and samples from the distribution to output multiple actions.
-
-The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.
-
-Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
-fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
-argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
-if not specified.
-
-In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
-the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.
+Step-by-step Explanation
+------------------------
+
+1. **Setting Composite Log-Probabilities**:
+   - To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
+     distributions, such as SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Not calling this
+     will result in errors during execution of your script.
+   - From torchrl and tensordict v0.9, this will be the default behavior. Until then, not calling it will result in
+     `CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
+   - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be named
+     `<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head
+     policies for instance, the log-probability will be named `"action_log_prob"`.
+     Previously, log-prob keys defaulted to `sample_log_prob`.
+2. **Action Grouping**:
+   - Actions can be grouped or not; PPO doesn't require them to be grouped.
+   - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
+     log-probability, e.g., `agent0`, `agent0_log_prob`, etc.
+
+     ... [...]
+     ... action: TensorDict(
+     ...     fields={
+     ...         agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...         agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...         agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...         agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...         agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...         agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+     ...     batch_size=torch.Size([4]),
+     ...     device=None,
+     ...     is_shared=False),
+
+   - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.
+
+     ... [...]
+     ... agent0: TensorDict(
+     ...     fields={
+     ...         action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...         action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+     ...     batch_size=torch.Size([4]),
+     ...     device=None,
+     ...     is_shared=False),
+     ... agent1: TensorDict(
+     ...     fields={
+     ...         action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...
action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)}, + ... batch_size=torch.Size([4]), + ... device=None, + ... is_shared=False), + ... agent2: TensorDict( + ... fields={ + ... action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), + ... action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, + ... batch_size=torch.Size([4]), + ... device=None, + ... is_shared=False), + +3. **PPO Loss Calculation**: + - Under the hood, `ClipPPO` will clip individual weights (not the aggregate) and multiply that by the advantage. + +The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses. """ @@ -38,6 +81,7 @@ InteractionType, ProbabilisticTensorDictModule as Prob, ProbabilisticTensorDictSequential as ProbSeq, + set_composite_lp_aggregate, TensorDictModule as Mod, TensorDictSequential as Seq, WrapModule as Wrap, @@ -45,6 +89,10 @@ from torch import distributions as d from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss +set_composite_lp_aggregate(False).set() + +GROUPED_ACTIONS = False + make_params = Mod( lambda: ( torch.ones(4), @@ -74,8 +122,18 @@ def mixture_constructor(logits, loc, scale): ) -# ============================================================================= -# Example 0: aggregate_probabilities=None (default) =========================== +if GROUPED_ACTIONS: + name_map = { + "gamma": ("action", "agent0"), + "Kumaraswamy": ("action", "agent1"), + "mixture": ("action", "agent2"), + } +else: + name_map = { + "gamma": ("agent0", "action"), + "Kumaraswamy": ("agent1", "action"), + "mixture": ("agent2", "action"), + } dist_constructor = functools.partial( CompositeDistribution, @@ -84,12 +142,7 @@ def mixture_constructor(logits, loc, scale): "Kumaraswamy": d.Kumaraswamy, "mixture": mixture_constructor, }, - name_map={ - "gamma": ("agent0", "action"), - "Kumaraswamy": ("agent1", "action"), - "mixture": ("agent2", "action"), - }, - aggregate_probabilities=None, + name_map=name_map, ) @@ -97,7 +150,7 @@ def mixture_constructor(logits, loc, scale): make_params, Prob( in_keys=["params"], - out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], + out_keys=list(name_map.values()), distribution_class=dist_constructor, return_log_prob=True, default_interaction_type=InteractionType.RANDOM, @@ -105,19 +158,11 @@ def mixture_constructor(logits, loc, scale): ) td = policy(TensorDict(batch_size=[4])) -print("0. result of policy call", td) +print("Result of policy call", td) dist = policy.get_dist(td) -log_prob = dist.log_prob( - td, aggregate_probabilities=False, inplace=False, include_sum=False -) -print("0. non-aggregated log-prob") - -# We can also get the log-prob from the policy directly -log_prob = policy.log_prob( - td, aggregate_probabilities=False, inplace=False, include_sum=False -) -print("0. 
non-aggregated log-prob (from policy)") +log_prob = dist.log_prob(td) +print("Composite log-prob", log_prob) # Build a dummy value operator value_operator = Seq( @@ -134,70 +179,12 @@ def mixture_constructor(logits, loc, scale): TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)), ) -# Instantiate the loss +# Instantiate the loss - test the 3 different PPO losses for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): + # PPO sets the keys automatically by looking at the policy ppo = loss_cls(policy, value_operator) - - # Keys are not the default ones - there is more than one action - ppo.set_keys( - action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], - sample_log_prob=[ - ("agent0", "action_log_prob"), - ("agent1", "action_log_prob"), - ("agent2", "action_log_prob"), - ], - ) - - # Get the loss values - loss_vals = ppo(data) - print("0. ", loss_cls, loss_vals) - - -# =================================================================== -# Example 1: aggregate_probabilities=True =========================== - -dist_constructor.keywords["aggregate_probabilities"] = True - -td = policy(TensorDict(batch_size=[4])) -print("1. result of policy call", td) - -# Instantiate the loss -for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): - ppo = loss_cls(policy, value_operator) - - # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since - # there is only one. - ppo.set_keys( - action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")] - ) - - # Get the loss values - loss_vals = ppo(data) - print("1. ", loss_cls, loss_vals) - - -# =================================================================== -# Example 2: aggregate_probabilities=False =========================== - -dist_constructor.keywords["aggregate_probabilities"] = False - -td = policy(TensorDict(batch_size=[4])) -print("2. result of policy call", td) - -# Instantiate the loss -for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss): - ppo = loss_cls(policy, value_operator) - - # Keys are not the default ones - there is more than one action - ppo.set_keys( - action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")], - sample_log_prob=[ - ("agent0", "action_log_prob"), - ("agent1", "action_log_prob"), - ("agent2", "action_log_prob"), - ], - ) + print("tensor keys", ppo.tensor_keys) # Get the loss values loss_vals = ppo(data) - print("2. ", loss_cls, loss_vals) + print("Loss result:", loss_cls, loss_vals) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index a5254dacddd..bf7831d518c 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -959,6 +959,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion # of the weights. lw = log_weight.squeeze() + if not isinstance(lw, torch.Tensor): + lw = _sum_td_features(lw) ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp() batch = log_weight.shape[0]
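
For reference, here is a minimal, standalone sketch of the effective-sample-size (ESS) diagnostic that the new
`isinstance` guard above handles. The keys, shapes, and the leaf-summing shortcut (standing in for torchrl's
internal `_sum_td_features` helper) are illustrative assumptions, not the library's exact implementation.

    # Minimal sketch (assumed keys/shapes): with composite log-prob aggregation
    # disabled, the importance log-weight returned by _log_weight can be a
    # TensorDict with one entry per action head rather than a single tensor.
    import torch
    from tensordict import TensorDict

    log_weight = TensorDict(
        {"agent0_log_prob": torch.randn(8, 1), "agent1_log_prob": torch.randn(8, 1)},
        batch_size=[8],
    )

    lw = log_weight  # a TensorDict here, not a torch.Tensor
    if not isinstance(lw, torch.Tensor):
        # stand-in for _sum_td_features: sum the per-head log-weights into a
        # single joint log-weight tensor
        lw = sum(lw.values()).squeeze(-1)

    # same ESS estimator as in the hunk above
    ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
    print("approximate effective sample size:", ess)

Summing the per-head log-weights corresponds to multiplying the per-head importance ratios, i.e., the joint ratio
when the action heads are treated as independent.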