Skip to content

Commit

Permalink
[Feature] Make PPO compatible with composite actions and log-probs
Browse files Browse the repository at this point in the history
ghstack-source-id: cbdaf533a39aeea41e3fbcda4e9d95a116eabfe1
Pull Request resolved: #2665
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent 1ce25f1 commit 94b60fe
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 31 deletions.
99 changes: 70 additions & 29 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
from typing import List, Tuple

import torch
from tensordict import (
is_tensor_collection,
TensorDict,
TensorDictBase,
TensorDictParams,
unravel_key,
)
from tensordict.nn import (
CompositeDistribution,
Expand All @@ -33,6 +34,8 @@
_cache_values,
_clip_value_loss,
_GAMMA_LMBDA_DEPREC_ERROR,
_maybe_add_or_extend_key,
_maybe_get_or_select,
_reduce,
_sum_td_features,
default_value_kwargs,
Expand Down Expand Up @@ -67,7 +70,10 @@ class PPOLoss(LossModule):
Args:
actor_network (ProbabilisticTensorDictSequential): policy operator.
critic_network (ValueOperator): value operator.
Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations
as input and outputting an action (or actions) as well as its log-probability value.
critic_network (ValueOperator): value operator. The critic will usually take the observations as input
and return a scalar value (``state_value`` by default) in the output keys.
Keyword Args:
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
Expand Down Expand Up @@ -267,28 +273,28 @@ class _AcceptedKeys:
Will be used for the underlying value estimator Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
sample_log_prob (NestedKey): The input tensordict key where the
sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the
sample log probability is expected. Defaults to ``"sample_log_prob"``.
action (NestedKey): The input tensordict key where the action is expected.
action (NestedKey or list of nested keys): The input tensordict key where the action is expected.
Defaults to ``"action"``.
reward (NestedKey): The input tensordict key where the reward is expected.
reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
done (NestedKey or list of nested keys): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
terminated (NestedKey): The key in the input TensorDict that indicates
terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates
whether a trajectory is terminated. Will be used for the underlying value estimator.
Defaults to ``"terminated"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
action: NestedKey = "action"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob"
action: NestedKey | List[NestedKey] = "action"
reward: NestedKey | List[NestedKey] = "reward"
done: NestedKey | List[NestedKey] = "done"
terminated: NestedKey | List[NestedKey] = "terminated"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.GAE
Expand Down Expand Up @@ -369,7 +375,7 @@ def __init__(

try:
device = next(self.parameters()).device
except AttributeError:
except (AttributeError, StopIteration):
device = torch.device("cpu")

self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
Expand Down Expand Up @@ -409,15 +415,36 @@ def functional(self):

def _set_in_keys(self):
keys = [
self.tensor_keys.action,
self.tensor_keys.sample_log_prob,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.critic_network.in_keys,
]

if isinstance(self.tensor_keys.action, NestedKey):
keys.append(self.tensor_keys.action)
else:
keys.extend(self.tensor_keys.action)

if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
keys.append(self.tensor_keys.sample_log_prob)
else:
keys.extend(self.tensor_keys.sample_log_prob)

if isinstance(self.tensor_keys.reward, NestedKey):
keys.append(unravel_key(("next", self.tensor_keys.reward)))
else:
keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.reward])

if isinstance(self.tensor_keys.done, NestedKey):
keys.append(unravel_key(("next", self.tensor_keys.done)))
else:
keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.done])

if isinstance(self.tensor_keys.terminated, NestedKey):
keys.append(unravel_key(("next", self.tensor_keys.terminated)))
else:
keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.terminated])

self._in_keys = list(set(keys))

@property
Expand Down Expand Up @@ -472,25 +499,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
if is_tensor_collection(entropy):
entropy = _sum_td_features(entropy)
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
if getattr(dist, "has_rsample", False):
x = dist.rsample((self.samples_mc_entropy,))
else:
x = dist.sample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
if is_tensor_collection(log_prob):

if is_tensor_collection(log_prob) and isinstance(
self.tensor_keys.sample_log_prob, NestedKey
):
log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
else:
log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)

entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)

def _log_weight(
self, tensordict: TensorDictBase
) -> Tuple[torch.Tensor, d.Distribution]:

# current log_prob of actions
action = tensordict.get(self.tensor_keys.action)
action = _maybe_get_or_select(tensordict, self.tensor_keys.action)

with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)

prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
prev_log_prob = _maybe_get_or_select(
tensordict, self.tensor_keys.sample_log_prob
)

if prev_log_prob.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
Expand All @@ -513,8 +553,8 @@ def _log_weight(
else:
is_composite = False
kwargs = {}
log_prob = dist.log_prob(tensordict, **kwargs)
if is_composite and not isinstance(prev_log_prob, TensorDict):
log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
if is_composite and not is_tensor_collection(prev_log_prob):
log_prob = _sum_td_features(log_prob)
log_prob.view_as(prev_log_prob)

Expand Down Expand Up @@ -1088,15 +1128,16 @@ def __init__(

def _set_in_keys(self):
keys = [
self.tensor_keys.action,
self.tensor_keys.sample_log_prob,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.critic_network.in_keys,
]
_maybe_add_or_extend_key(keys, self.tensor_keys.action)
_maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
_maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
_maybe_add_or_extend_key(keys, self.tensor_keys.done, "next")
_maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")

# Get the parameter keys from the actor dist
actor_dist_module = None
for module in self.actor_network.modules():
Expand Down
27 changes: 25 additions & 2 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import re
import warnings
from enum import Enum
from typing import Iterable, Optional, Union
from typing import Iterable, List, Optional, Union

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.nn import TensorDictModule
from torch import nn, Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize
def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
# Sum all features and return a tensor
return data.sum(dim="feature", reduce=True)


def _maybe_get_or_select(td, key_or_keys):
if isinstance(key_or_keys, (str, tuple)):
return td.get(key_or_keys)
return td.select(*key_or_keys)


def _maybe_add_or_extend_key(
tensor_keys: List[NestedKey],
key_or_list_of_keys: NestedKey | List[NestedKey],
prefix: NestedKey = None,
):
if prefix is not None:
if isinstance(key_or_list_of_keys, NestedKey):
tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
else:
tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
return
if isinstance(key_or_list_of_keys, NestedKey):
tensor_keys.append(key_or_list_of_keys)
else:
tensor_keys.extend(key_or_list_of_keys)

0 comments on commit 94b60fe

Please sign in to comment.