From 3d1e978d6a882f2653650436105f9d3d323c70f1 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 26 Dec 2024 16:11:30 +0100 Subject: [PATCH 1/9] multiagent norm --- torchrl/_utils.py | 54 ++++++++++++++++++++++++++++++++++++++- torchrl/objectives/ppo.py | 29 +++++++++++++-------- 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 6a2f80aeffb..ddd92d58341 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -24,7 +24,7 @@ from distutils.util import strtobool from functools import wraps from importlib import import_module -from typing import Any, Callable, cast, Dict, TypeVar, Union +from typing import Any, Callable, cast, Dict, Sequence, TypeVar, Union import numpy as np import torch @@ -872,6 +872,58 @@ def set_mode(self, type: Any | None) -> None: self._mode = type +def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None): + """Standardizes the input tensor with the possibility of excluding specific dims from the statistics. + + Useful when processing multi-agent data to keep the agent dimensions independent. + + Args: + input (Tensor): the input tensor to be standardized. + exclude_dims (Sequence[int]): dimensions to exclude from the statistics, can be negative. Default: (). + mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None. + std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None. + + """ + input_shape = input.shape + exclude_dims = [ + d if d >= 0 else d + len(input_shape) for d in exclude_dims + ] # Make negative dims positive + + if len(set(exclude_dims)) != len(exclude_dims): + raise ValueError("Exclude dims has repeating elements") + if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims): + raise ValueError( + f"exclude_dims provided outside bounds for input of shape={input_shape}" + ) + if len(exclude_dims) == len(input_shape): + warnings.warn( + "standardize called but all dims were excluded from the statistics, returning unprocessed input" + ) + return input + + # Put all excluded dims in the beginning + permutation = list(range(len(input_shape))) + for dim in exclude_dims: + permutation.insert(0, permutation.pop(permutation.index(dim))) + permuted_input = input.permute(*permutation) + normalized_shape_len = len(input_shape) - len(exclude_dims) + + if mean is None: + mean = torch.mean( + permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0)) + ) + if std is None: + std = torch.std( + permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0)) + ) + output = (permuted_input - mean) / std.clamp_min(1e-6) + + # Reverse permutation + inv_permutation = torch.argsort(torch.LongTensor(permutation)).tolist() + output = torch.permute(output, inv_permutation) + return output + + @wraps(torch.compile) def compile_with_warmup(*args, warmup: int = 1, **kwargs): """Compile a model with warm-up. 
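As a quick illustration of the helper added above, the sketch below (the advantage shape and the choice of dim 1 as the agent dimension are assumptions for the example, not part of the patch) shows how excluding a dimension keeps one statistic per agent:

    import torch
    from torchrl._utils import _standardize

    # Hypothetical multi-agent advantage with shape [batch, n_agents, 1].
    advantage = torch.randn(64, 3, 1)

    # Excluding the agent dimension yields one mean/std per agent, computed
    # over the remaining dims, instead of a single global statistic.
    per_agent = _standardize(advantage, exclude_dims=(1,))

    # The same reduction written out by hand for comparison.
    mean = advantage.mean(dim=(0, 2), keepdim=True)
    std = advantage.std(dim=(0, 2), keepdim=True).clamp_min(1e-6)
    torch.testing.assert_close(per_agent, (advantage - mean) / std)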
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index eb9a916dfc1..36068be2aaa 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -8,7 +8,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import Sequence, Tuple import torch from tensordict import ( @@ -27,6 +27,7 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl._utils import _standardize from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -87,6 +88,8 @@ class PPOLoss(LossModule): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. + normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage + standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. Defaults to ``False``, i.e., gradients are propagated to shared @@ -311,6 +314,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, + normalize_advantage_exclude_dims: Sequence[int] = (), gamma: float = None, separate_losses: bool = False, advantage_key: str = None, @@ -381,6 +385,8 @@ def __init__( self.critic_coef = None self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage + self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims + if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._set_deprecated_ctor_keys( @@ -606,9 +612,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: - loc = advantage.mean() - scale = advantage.std().clamp_min(1e-6) - advantage = (advantage - loc) / scale + advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict) if is_tensor_collection(log_weight): @@ -711,6 +715,8 @@ class ClipPPOLoss(PPOLoss): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. + normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage + standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. 
Defaults to ``False``, i.e., gradients are propagated to shared @@ -802,6 +808,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, + normalize_advantage_exclude_dims: Sequence[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, @@ -821,6 +828,7 @@ def __init__( critic_coef=critic_coef, loss_critic_type=loss_critic_type, normalize_advantage=normalize_advantage, + normalize_advantage_exclude_dims=normalize_advantage_exclude_dims, gamma=gamma, separate_losses=separate_losses, reduction=reduction, @@ -871,9 +879,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: - loc = advantage.mean() - scale = advantage.std().clamp_min(1e-6) - advantage = (advantage - loc) / scale + advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict) # ESS for logging @@ -955,6 +961,8 @@ class KLPENPPOLoss(PPOLoss): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. + normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage + standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. Defaults to ``False``, i.e., gradients are propagated to shared @@ -1048,6 +1056,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, + normalize_advantage_exclude_dims: Sequence[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, @@ -1063,6 +1072,7 @@ def __init__( critic_coef=critic_coef, loss_critic_type=loss_critic_type, normalize_advantage=normalize_advantage, + normalize_advantage_exclude_dims=normalize_advantage_exclude_dims, gamma=gamma, separate_losses=separate_losses, reduction=reduction, @@ -1151,9 +1161,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: ) advantage = tensordict_copy.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: - loc = advantage.mean() - scale = advantage.std().clamp_min(1e-6) - advantage = (advantage - loc) / scale + advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) + log_weight, dist, kl_approx = self._log_weight(tensordict_copy) neg_loss = log_weight.exp() * advantage From 9cfe213c7b9b2c51e54c64e5e4b386efa5b09ee9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 10 Jan 2025 09:37:59 +0100 Subject: [PATCH 2/9] review comments --- torchrl/_utils.py | 35 ++++++++++++++++++++++++----------- torchrl/objectives/ppo.py | 23 +++++++++++++---------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index ddd92d58341..50be67768e9 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -24,7 +24,7 @@ from distutils.util import strtobool from functools import wraps from importlib import import_module -from typing import Any, Callable, cast, Dict, Sequence, TypeVar, Union +from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union import numpy as np import torch @@ -872,7 +872,9 @@ def set_mode(self, type: Any | 
None) -> None: self._mode = type -def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None): +def _standardize( + input, exclude_dims: Tuple[int] = (), mean=None, std=None, eps: float = None +): """Standardizes the input tensor with the possibility of excluding specific dims from the statistics. Useful when processing multi-agent data to keep the agent dimensions independent. @@ -882,8 +884,12 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None): exclude_dims (Sequence[int]): dimensions to exclude from the statistics, can be negative. Default: (). mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None. std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None. + eps (float): epsilon to be used for numerical stability. Default: float32 resolution. """ + if eps is None: + eps = torch.finfo(torch.float.dtype).resolution + input_shape = input.shape exclude_dims = [ d if d >= 0 else d + len(input_shape) for d in exclude_dims @@ -893,7 +899,7 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None): raise ValueError("Exclude dims has repeating elements") if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims): raise ValueError( - f"exclude_dims provided outside bounds for input of shape={input_shape}" + f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}" ) if len(exclude_dims) == len(input_shape): warnings.warn( @@ -901,11 +907,14 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None): ) return input - # Put all excluded dims in the beginning - permutation = list(range(len(input_shape))) - for dim in exclude_dims: - permutation.insert(0, permutation.pop(permutation.index(dim))) - permuted_input = input.permute(*permutation) + if len(exclude_dims): + # Put all excluded dims in the beginning + permutation = list(range(len(input_shape))) + for dim in exclude_dims: + permutation.insert(0, permutation.pop(permutation.index(dim))) + permuted_input = input.permute(*permutation) + else: + permuted_input = input normalized_shape_len = len(input_shape) - len(exclude_dims) if mean is None: @@ -916,11 +925,15 @@ def _standardize(input, exclude_dims: Sequence[int] = (), mean=None, std=None): std = torch.std( permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0)) ) - output = (permuted_input - mean) / std.clamp_min(1e-6) + output = (permuted_input - mean) / std.clamp_min(eps) # Reverse permutation - inv_permutation = torch.argsort(torch.LongTensor(permutation)).tolist() - output = torch.permute(output, inv_permutation) + if len(exclude_dims): + inv_permutation = torch.argsort( + torch.tensor(permutation, dtype=torch.long, device=input.device) + ).tolist() + output = torch.permute(output, inv_permutation) + return output diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 36068be2aaa..5e5d4ea006e 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -8,7 +8,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Sequence, Tuple +from typing import Tuple import torch from tensordict import ( @@ -88,8 +88,9 @@ class PPOLoss(LossModule): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. 
- normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage - standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: (). + normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage + standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings + where the agent (or objective) dimension may be excluded from the reductions. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. Defaults to ``False``, i.e., gradients are propagated to shared @@ -314,7 +315,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Sequence[int] = (), + normalize_advantage_exclude_dims: Tuple[int] = (), gamma: float = None, separate_losses: bool = False, advantage_key: str = None, @@ -715,8 +716,9 @@ class ClipPPOLoss(PPOLoss): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. - normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage - standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: (). + normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage + standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings + where the agent (or objective) dimension may be excluded from the reductions. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. Defaults to ``False``, i.e., gradients are propagated to shared @@ -808,7 +810,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Sequence[int] = (), + normalize_advantage_exclude_dims: Tuple[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, @@ -961,8 +963,9 @@ class KLPENPPOLoss(PPOLoss): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized before being used. Defaults to ``False``. - normalize_advantage_exclude_dims (Sequence[int], optional): dimensions to exclude from the advantage - standardization, can be negative. Useful in multiagent settings to exlude the agent dimension. Default: (). + normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage + standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings + where the agent (or objective) dimension may be excluded from the reductions. Default: (). separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. 
Defaults to ``False``, i.e., gradients are propagated to shared @@ -1056,7 +1059,7 @@ def __init__( critic_coef: float = 1.0, loss_critic_type: str = "smooth_l1", normalize_advantage: bool = False, - normalize_advantage_exclude_dims: Sequence[int] = (), + normalize_advantage_exclude_dims: Tuple[int] = (), gamma: float = None, separate_losses: bool = False, reduction: str = None, From fbe20d4dc5e1d2498771a102bb88ace00811bb02 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 10 Jan 2025 09:45:44 +0100 Subject: [PATCH 3/9] warning --- torchrl/objectives/ppo.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 5e5d4ea006e..3d1b3bd5088 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -47,6 +47,7 @@ TDLambdaEstimator, VTrace, ) +import warnings class PPOLoss(LossModule): @@ -613,6 +614,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: + if advantage.numel() > tensordict.batch_size.numel() and not len( + self.normalize_advantage_exclude_dims + ): + warnings.warn( + "You requested advantage normalization and the advantage key has more dimensions" + " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` " + "if you want to keep any dimension independent while computing normalization statistics. " + "If you are working in multi-agent/multi-objective settings this is highly suggested." + ) advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict) @@ -881,6 +891,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: + if advantage.numel() > tensordict.batch_size.numel() and not len( + self.normalize_advantage_exclude_dims + ): + warnings.warn( + "You requested advantage normalization and the advantage key has more dimensions" + " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` " + "if you want to keep any dimension independent while computing normalization statistics. " + "If you are working in multi-agent/multi-objective settings this is highly suggested." + ) advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict) @@ -1164,6 +1183,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: ) advantage = tensordict_copy.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: + if advantage.numel() > tensordict.batch_size.numel() and not len( + self.normalize_advantage_exclude_dims + ): + warnings.warn( + "You requested advantage normalization and the advantage key has more dimensions" + " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` " + "if you want to keep any dimension independent while computing normalization statistics. " + "If you are working in multi-agent/multi-objective settings this is highly suggested."
+ ) advantage = _standardize(advantage, self.normalize_advantage_exclude_dims) log_weight, dist, kl_approx = self._log_weight(tensordict_copy) From 1f9e7dfd21a54d9032e86e2159539c5ba81ba2a2 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 13 Jan 2025 20:43:40 +0100 Subject: [PATCH 4/9] special case for no dim exclusion --- torchrl/_utils.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 50be67768e9..cc90ce22dce 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -890,32 +890,38 @@ def _standardize( if eps is None: eps = torch.finfo(torch.float.dtype).resolution + len_exclude_dims = len(exclude_dims) + if not len_exclude_dims: + if mean is None: + mean = input.mean() + if std is None: + std = input.std() + return (input - mean) / std.clamp_min(eps) + input_shape = input.shape exclude_dims = [ d if d >= 0 else d + len(input_shape) for d in exclude_dims ] # Make negative dims positive - if len(set(exclude_dims)) != len(exclude_dims): + if len(set(exclude_dims)) != len_exclude_dims: raise ValueError("Exclude dims has repeating elements") if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims): raise ValueError( f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}" ) - if len(exclude_dims) == len(input_shape): + if len_exclude_dims == len(input_shape): warnings.warn( "standardize called but all dims were excluded from the statistics, returning unprocessed input" ) return input - if len(exclude_dims): - # Put all excluded dims in the beginning - permutation = list(range(len(input_shape))) - for dim in exclude_dims: - permutation.insert(0, permutation.pop(permutation.index(dim))) - permuted_input = input.permute(*permutation) - else: - permuted_input = input - normalized_shape_len = len(input_shape) - len(exclude_dims) + # Put all excluded dims in the beginning + permutation = list(range(len(input_shape))) + for dim in exclude_dims: + permutation.insert(0, permutation.pop(permutation.index(dim))) + permuted_input = input.permute(*permutation) + + normalized_shape_len = len(input_shape) - len_exclude_dims if mean is None: mean = torch.mean( @@ -928,11 +934,10 @@ def _standardize( output = (permuted_input - mean) / std.clamp_min(eps) # Reverse permutation - if len(exclude_dims): - inv_permutation = torch.argsort( - torch.tensor(permutation, dtype=torch.long, device=input.device) - ).tolist() - output = torch.permute(output, inv_permutation) + inv_permutation = torch.argsort( + torch.tensor(permutation, dtype=torch.long, device=input.device) + ).tolist() + output = torch.permute(output, inv_permutation) return output From 2917287f30f7bebc142e24d9951df0051ca5bc64 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 13 Jan 2025 20:44:35 +0100 Subject: [PATCH 5/9] typo --- torchrl/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index cc90ce22dce..672eb9f512d 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -911,7 +911,7 @@ def _standardize( ) if len_exclude_dims == len(input_shape): warnings.warn( - "standardize called but all dims were excluded from the statistics, returning unprocessed input" + "_standardize called but all dims were excluded from the statistics, returning unprocessed input" ) return input From 926a08d1eeff4110336e30551cc1599c4e3c344a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 13 Jan 2025 20:47:20 +0100 Subject: [PATCH 6/9] better type 
suggestions --- torchrl/_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 672eb9f512d..1d28e71022a 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -32,7 +32,7 @@ from tensordict import unravel_key from tensordict.utils import NestedKey -from torch import multiprocessing as mp +from torch import multiprocessing as mp, Tensor try: from torch.compiler import is_compiling @@ -873,7 +873,11 @@ def set_mode(self, type: Any | None) -> None: def _standardize( - input, exclude_dims: Tuple[int] = (), mean=None, std=None, eps: float = None + input: Tensor, + exclude_dims: Tuple[int] = (), + mean: Tensor | None = None, + std: Tensor | None = None, + eps: float | None = None, ): """Standardizes the input tensor with the possibility of excluding specific dims from the statistics. From 1b428f0d69ff279319e13172df879d0f1ed6b166 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 15 Jan 2025 11:03:57 +0100 Subject: [PATCH 7/9] use included_dims instead of excluded --- torchrl/_utils.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 1d28e71022a..56f5c32b3af 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -919,31 +919,12 @@ def _standardize( ) return input - # Put all excluded dims in the beginning - permutation = list(range(len(input_shape))) - for dim in exclude_dims: - permutation.insert(0, permutation.pop(permutation.index(dim))) - permuted_input = input.permute(*permutation) - - normalized_shape_len = len(input_shape) - len_exclude_dims - + included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims) if mean is None: - mean = torch.mean( - permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0)) - ) + mean = torch.mean(input, keepdim=True, dim=included_dims) if std is None: - std = torch.std( - permuted_input, keepdim=True, dim=tuple(range(-normalized_shape_len, 0)) - ) - output = (permuted_input - mean) / std.clamp_min(eps) - - # Reverse permutation - inv_permutation = torch.argsort( - torch.tensor(permutation, dtype=torch.long, device=input.device) - ).tolist() - output = torch.permute(output, inv_permutation) - - return output + std = torch.std(input, keepdim=True, dim=included_dims) + return (input - mean) / std.clamp_min(eps) @wraps(torch.compile) From 278a6c0be176c6984e0a12b6d8abd10c1f8b5901 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 15 Jan 2025 11:04:27 +0100 Subject: [PATCH 8/9] doc --- torchrl/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 56f5c32b3af..cfc54ad207c 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -885,7 +885,7 @@ def _standardize( Args: input (Tensor): the input tensor to be standardized. - exclude_dims (Sequence[int]): dimensions to exclude from the statistics, can be negative. Default: (). + exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: (). mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None. std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None. eps (float): epsilon to be used for numerical stability. Default: float32 resolution. 
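To make the simplified reduction above concrete, here is a small sanity-check sketch (the shapes, values and the explicit eps are illustrative choices, not part of the patch):

    import torch
    from torchrl._utils import _standardize

    x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

    # Excluding dim 1 (e.g. an agent dimension) reduces the statistics over
    # the remaining dims (0, 2), one mean/std per index along dim 1.
    out = _standardize(x, exclude_dims=(1,), eps=1e-6)
    ref = (x - x.mean(dim=(0, 2), keepdim=True)) / x.std(
        dim=(0, 2), keepdim=True
    ).clamp_min(1e-6)
    torch.testing.assert_close(out, ref)

    # Precomputed statistics (e.g. running estimates) can be supplied instead;
    # they only need to broadcast against the input.
    shifted = _standardize(x, mean=torch.tensor(-1.0), std=torch.tensor(2.0), eps=1e-6)
    torch.testing.assert_close(shifted, (x + 1) / 2)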
From 121bbc16b85f92d37b46f1c9dae6efd383871ba8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 Jan 2025 16:03:24 +0000 Subject: [PATCH 9/9] add tests and fix bugs --- test/test_cost.py | 14 +++++++++++++- torchrl/_utils.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 1f191e41db6..a0283e0e276 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -19,7 +19,6 @@ import torch from packaging import version, version as pack_version - from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict._C import unravel_keys from tensordict.nn import ( @@ -38,6 +37,7 @@ from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn +from torchrl._utils import _standardize from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv @@ -15848,6 +15848,18 @@ class _AcceptedKeys: class TestUtils: + def test_standardization(self): + t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6) + std_t0 = _standardize(t, exclude_dims=(1, 3)) + std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std((0, 2), keepdim=True).clamp( + 1e-6 + ) + torch.testing.assert_close(std_t0, std_t1) + std_t = _standardize(t, (), -1, 2) + torch.testing.assert_close(std_t, (t + 1) / 2) + std_t = _standardize(t, ()) + torch.testing.assert_close(std_t, (t - t.mean()) / t.std()) + @pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip @pytest.mark.parametrize("T", [1, 10]) @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index cfc54ad207c..f999fa96c1d 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -892,14 +892,23 @@ def _standardize( """ if eps is None: - eps = torch.finfo(torch.float.dtype).resolution + if input.dtype.is_floating_point: + eps = torch.finfo(torch.float).resolution + else: + eps = 1e-6 len_exclude_dims = len(exclude_dims) if not len_exclude_dims: if mean is None: mean = input.mean() + else: + # Assume dtypes are compatible + mean = torch.as_tensor(mean, device=input.device) if std is None: std = input.std() + else: + # Assume dtypes are compatible + std = torch.as_tensor(std, device=input.device) return (input - mean) / std.clamp_min(eps) input_shape = input.shape
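Taken together, the new argument is meant to be wired in from the loss constructors roughly as below. This is a sketch only: the single-layer actor and critic, the observation size and the assumption that the agent dimension sits at -2 of the advantage are placeholders, not part of the patch.

    import torch.nn as nn
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import ClipPPOLoss

    obs_dim, act_dim = 8, 2

    actor = ProbabilisticActor(
        module=TensorDictModule(
            nn.Sequential(nn.Linear(obs_dim, 2 * act_dim), NormalParamExtractor()),
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )
    critic = ValueOperator(nn.Linear(obs_dim, 1), in_keys=["observation"])

    loss_module = ClipPPOLoss(
        actor_network=actor,
        critic_network=critic,
        normalize_advantage=True,
        # With an advantage of shape [*batch, n_agents, 1], excluding dim -2
        # keeps one normalization statistic per agent.
        normalize_advantage_exclude_dims=(-2,),
    )

The same argument is exposed on PPOLoss and KLPENPPOLoss, so the excluded dims can follow however the advantage is laid out by the value estimator.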