
[Feature] multiagent data standardization: PPO advantages #2677

Merged · 9 commits · Jan 15, 2025
Changes from 3 commits
67 changes: 66 additions & 1 deletion torchrl/_utils.py
@@ -24,7 +24,7 @@
from distutils.util import strtobool
from functools import wraps
from importlib import import_module
from typing import Any, Callable, cast, Dict, TypeVar, Union
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

import numpy as np
import torch
@@ -872,6 +872,71 @@ def set_mode(self, type: Any | None) -> None:
self._mode = type


def _standardize(
    input, exclude_dims: Tuple[int, ...] = (), mean=None, std=None, eps: float | None = None
):
"""Standardizes the input tensor with the possibility of excluding specific dims from the statistics.

Useful when processing multi-agent data to keep the agent dimensions independent.

Args:
input (Tensor): the input tensor to be standardized.
        exclude_dims (Tuple[int, ...]): dimensions to exclude from the statistics; can be negative. Default: ().
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.

"""
if eps is None:
        eps = torch.finfo(torch.float32).resolution

input_shape = input.shape
exclude_dims = [
d if d >= 0 else d + len(input_shape) for d in exclude_dims
] # Make negative dims positive

if len(set(exclude_dims)) != len(exclude_dims):
raise ValueError("Exclude dims has repeating elements")
if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
raise ValueError(
f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
)
if len(exclude_dims) == len(input_shape):
warnings.warn(
"standardize called but all dims were excluded from the statistics, returning unprocessed input"
)
return input
Review comment (Contributor): All of this can be skipped if exclude_dims is empty.


if len(exclude_dims):
# Put all excluded dims in the beginning
permutation = list(range(len(input_shape)))
for dim in exclude_dims:
permutation.insert(0, permutation.pop(permutation.index(dim)))
permuted_input = input.permute(*permutation)
else:
permuted_input = input
normalized_shape_len = len(input_shape) - len(exclude_dims)
Review comment (Contributor): This doesn't need to be computed if exclude_dims is empty.


if mean is None:
mean = torch.mean(
permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0))
)
if std is None:
std = torch.std(
permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0))
)
vmoens marked this conversation as resolved.
output = (permuted_input - mean) / std.clamp_min(eps)

# Reverse permutation
if len(exclude_dims):
Review comment (Contributor): Same here; we're checking multiple times if exclude_dims is empty when it could be done once.

inv_permutation = torch.argsort(
torch.tensor(permutation, dtype=torch.long, device=input.device)
).tolist()
output = torch.permute(output, inv_permutation)

return output


@wraps(torch.compile)
def compile_with_warmup(*args, warmup: int = 1, **kwargs):
"""Compile a model with warm-up.
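The three review comments above all point at the same issue: _standardize checks whether exclude_dims is empty in several places. Below is a minimal sketch of the suggested consolidation, a hypothetical refactor illustrating the reviewer's point rather than the code merged in this PR:

import torch

def _standardize_sketch(input, exclude_dims=(), mean=None, std=None, eps=None):
    # Hypothetical early-return variant of _standardize, per the review
    # comments: decide once whether anything is excluded.
    if eps is None:
        eps = torch.finfo(torch.float32).resolution
    if not exclude_dims:
        # Fast path: nothing is excluded, so normalize over all elements
        # with no permutation or shape bookkeeping.
        mean = input.mean() if mean is None else mean
        std = input.std() if std is None else std
        return (input - mean) / std.clamp_min(eps)
    # Slow path: validate exclude_dims, permute the excluded dims to the
    # front, reduce over the remaining dims, and invert the permutation,
    # exactly as in the diff above.
    ...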
58 changes: 49 additions & 9 deletions torchrl/objectives/ppo.py
@@ -27,6 +27,7 @@
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl._utils import _standardize
from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
@@ -46,6 +47,7 @@
TDLambdaEstimator,
VTrace,
)
import warnings


class PPOLoss(LossModule):
@@ -87,6 +89,9 @@ class PPOLoss(LossModule):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -311,6 +316,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
advantage_key: str = None,
@@ -381,6 +387,8 @@ def __init__(
self.critic_coef = None
self.loss_critic_type = loss_critic_type
self.normalize_advantage = normalize_advantage
self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._set_deprecated_ctor_keys(
@@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
loc = advantage.mean()
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
            warnings.warn(
                "You requested advantage normalization and the advantage tensor has more elements"
                " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims`"
                " if you want to keep any dimension independent while computing the normalization"
                " statistics. In multi-agent/multi-objective settings this is strongly recommended."
            )
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
if is_tensor_collection(log_weight):
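To make the new code path concrete, here is a small usage sketch of _standardize (the shapes are illustrative, and it assumes a torchrl build that includes this patch):

import torch
from torchrl._utils import _standardize

# A multi-agent advantage: batch [8, 100], 4 agents, one value per agent.
advantage = torch.randn(8, 100, 4, 1)

# advantage.numel() (8 * 100 * 4) exceeds the tensordict batch numel
# (8 * 100), so normalize_advantage=True without exclude dims would
# trigger the warning above. Excluding dim -2 keeps each agent's
# statistics independent:
normed = _standardize(advantage, exclude_dims=(-2,))

# Per-agent means are now approximately zero.
print(normed.mean(dim=(0, 1, 3)))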
@@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -802,6 +820,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
@@ -821,6 +840,7 @@
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
@@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
advantage = tensordict.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
loc = advantage.mean()
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
            warnings.warn(
                "You requested advantage normalization and the advantage tensor has more elements"
                " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims`"
                " if you want to keep any dimension independent while computing the normalization"
                " statistics. In multi-agent/multi-objective settings this is strongly recommended."
            )
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict)
# ESS for logging
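At the API level, enabling this in a multi-agent training script would look roughly like the following sketch; actor and critic stand in for pre-built TensorDictModule components whose construction is elided here:

from torchrl.objectives import ClipPPOLoss

loss_module = ClipPPOLoss(
    actor_network=actor,    # hypothetical pre-built policy module
    critic_network=critic,  # hypothetical pre-built value module
    normalize_advantage=True,
    # Exclude the agent dimension (dim -2 of a [*batch, n_agents, 1]
    # advantage) from the normalization statistics:
    normalize_advantage_exclude_dims=(-2,),
)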
@@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
before being used. Defaults to ``False``.
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1048,6 +1078,7 @@ def __init__(
critic_coef: float = 1.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
normalize_advantage_exclude_dims: Tuple[int] = (),
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
@@ -1063,6 +1094,7 @@
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
@@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
)
advantage = tensordict_copy.get(self.tensor_keys.advantage)
if self.normalize_advantage and advantage.numel() > 1:
loc = advantage.mean()
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale
if advantage.numel() > tensordict.batch_size.numel() and not len(
self.normalize_advantage_exclude_dims
):
            warnings.warn(
                "You requested advantage normalization and the advantage tensor has more elements"
                " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims`"
                " if you want to keep any dimension independent while computing the normalization"
                " statistics. In multi-agent/multi-objective settings this is strongly recommended."
            )
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
neg_loss = log_weight.exp() * advantage
