
[BUG] GAE parameters (gamma, lmbda) seemed to get changed by ClipPPOLoss, advantage module does not calculate loss_critic #2462


Description


Environment

OS: Windows 11
Python: CPython 3.10.14
TorchRL version: 0.5.0
PyTorch version: 2.4.1+cu124
Environment: a custom subclass of EnvBase (from torchrl.envs)

The project I'm working on is relatively complex, so I only include the parts of the code that I know are related to the bug described below.
Here are the definitions of the actor, value (critic), advantage, and loss modules.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator, OneHotCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# Action spec: three discrete actions
self.action_spec = DiscreteTensorSpec(3, dtype=torch.int8)

# Actor
_actor = TensorDictModule(self.agent, in_keys=self.agent.in_keys,
                          out_keys=self.agent.out_keys).to(self.agent.device)

self.actor_module = ProbabilisticActor(
    _actor,
    in_keys=self.agent.out_keys,
    spec=self.action_spec,
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)

# Critic
self.value_net = MyValueNetwork(device=agent.device)
self.value_module = ValueOperator(
    self.value_net, in_keys=self.value_net.in_keys, 
)

# Advantage
self.advantage_module = GAE(
    value_network=self.value_module,
    gamma=self.advantage_gamma,
    lmbda=self.advantage_lmbda,
    differentiable=True,
    average_gae=True,
    device=self.value_module.device,
)

# Loss
entropy_eps = 0.001
self.loss_module = ClipPPOLoss(
    actor_network=self.actor_module,
    critic_network=self.value_module,
    clip_epsilon=(0.2, ),
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)
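
Not shown above are the two optimizers used further down; they are created along these lines (a sketch: the optimizer class and learning rates here are placeholders, not my exact settings):

# Placeholder optimizer setup; the real class/learning rates differ.
self.actor_optimizer = torch.optim.Adam(self.actor_module.parameters(), lr=3e-4)
self.critic_optimizer = torch.optim.Adam(self.value_module.parameters(), lr=1e-3)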

Training loop

My training loop receives batched data from a MultiSyncDataCollector, adds it to a replay buffer backed by a LazyTensorStorage, then samples from the buffer and passes the sample to the _optimize_sample function:

def _optimize_sample(self, sample: TensorDict):
    self.actor_optimizer.zero_grad()
    self.critic_optimizer.zero_grad()

    # Lazily build the critic's weights on the first batch it sees
    if not self.value_net.weights_initialized:
        self.value_net(sample["observation"])
        self.value_net.weights_initialized = True

    # Run the actor on the sample, then detach the log-probs it wrote
    self.actor_module(sample)
    sample["sample_log_prob"] = sample["sample_log_prob"].detach()

    # Compute GAE advantages and value targets
    self.advantage_module(sample)

    loss_vals = self.loss_module(sample)
    total_loss = loss_vals["loss_entropy"] + \
        loss_vals["loss_objective"] + loss_vals["loss_critic"]

    total_loss.backward()

    torch.nn.utils.clip_grad_norm_(
        self.actor_module.parameters(), max_norm=0.1)
    torch.nn.utils.clip_grad_norm_(
        self.value_module.parameters(), max_norm=1.0)

    self.actor_optimizer.step()
    self.critic_optimizer.step()

    return loss_vals
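
For context, the outer loop that feeds _optimize_sample looks roughly like the sketch below; the buffer size, sample size and the replay_buffer name are placeholders rather than my exact code:

from torchrl.data import ReplayBuffer, LazyTensorStorage

# Rough sketch of the outer loop (placeholder sizes and names)
replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=100_000))

for batch in my_collector:                    # batched data from the MultiSyncDataCollector
    replay_buffer.extend(batch.reshape(-1))   # flatten batch/time dims before storing
    sample = replay_buffer.sample(256)
    loss_vals = self._optimize_sample(sample)
    my_collector.update_policy_weights_()     # keep the workers' copy of the policy in sync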

When I first ran this without the detach on sample_log_prob, the call to self.loss_module(sample) raised the error below (the preceding self.actor_module(sample) call writes a sample_log_prob that requires grad):

loss_vals = self.loss_module(sample)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\common.py", line 39, in new_forward
    return func(self, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\ppo.py", line 817, in forward
    log_weight, dist, kl_approx = self._log_weight(tensordict)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\ppo.py", line 473, in _log_weight
    raise RuntimeError("tensordict prev_log_prob requires grad.")
RuntimeError: tensordict prev_log_prob requires grad.

So I added sample["sample_log_prob"] = sample["sample_log_prob"].detach() (as shown in the training loop above) to detach sample_log_prob from the computation graph, and the issue was solved.

At this stage the model seems to converge, as both the objective and the critic loss decrease:

Figure 1 - Objective/policy loss (exponential moving average, window 100)

Figure 2 - Critic loss

The main issue

Up to this point everything appears to work, but the main issue occurs when I connect the actor (policy) module to the collector, so that data is collected with the current policy instead of random actions:

train_kwargs["policy_device"] = self.agent.device
train_kwargs["policy"] = self.actor_module
my_collector = MultiSyncDataCollector(**train_kwargs)
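
For reference, the rest of train_kwargs is along these lines (the env factories and frame counts below are placeholders for my actual setup):

# Placeholder contents of train_kwargs; the real env factories and frame counts differ.
train_kwargs = dict(
    create_env_fn=[make_my_env, make_my_env],  # one env factory per collection worker
    frames_per_batch=2048,
    total_frames=-1,                           # run the collector until it is stopped
)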

And when I run it, I get the error below (thrown inside self.advantage_module(sample)):

self.advantage_module(sample)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 68, in new_func
    return fun(self, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 57, in new_fun
    return fun(self, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
    return func(_self, tensordict, *args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 1357, in forward
    adv, value_target = vec_generalized_advantage_estimate(
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 89, in transposed_fun
    out = fun(*args, **kwargs)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 315, in vec_generalized_advantage_estimate
    return _fast_vec_gae(
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 250, in _fast_vec_gae
    advantage = _custom_conv1d(td0_flat.unsqueeze(1), gammalmbdas)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\utils.py", line 78, in _custom_conv1d
    val_pad = torch.nn.functional.pad(tensor, [0, filter.shape[-2] - 1])
IndexError: tuple index out of range

Digging into torchrl\objectives\value\functional.py, inside vec_generalized_advantage_estimate (around line 307), I found that the offending value is a 1-D vector of zeros whose length equals the sample batch size, whereas when the actor_module is not connected to the collector it is the expected matrix of multiplied gammas and lambdas (with one column). I also found that, as soon as the collector uses the actor module, the gamma and lmbda buffers of the advantage module are reset to 0.0: inside the training loop, print("Gamma : ", self.advantage_module.get_buffer("gamma")) outputs tensor(0.).
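
A quick way to confirm this is to dump the advantage module's buffers from inside the training loop; once the actor is attached to the collector, gamma and lmbda both come back as zero on my side:

# Print every buffer registered on the GAE module; after attaching the actor
# to the collector, gamma and lmbda both show up as tensor(0.).
for name, buf in self.advantage_module.named_buffers():
    print(name, "->", buf)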

So I added these two lines after the loss module definition:

self.advantage_module.register_buffer("gamma", torch.tensor(self.advantage_gamma))
self.advantage_module.register_buffer("lmbda", torch.tensor(self.advantage_lmbda))

With these two lines added, the previous error vanished, but a new issue appeared:

loss_vals["loss_objective"] + loss_vals["loss_critic"]
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\base.py", line 335, in __getitem__
    result = self._get_tuple(idx_unravel, NO_DEFAULT)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\_td.py", line 2399, in _get_tuple
    first = self._get_str(key[0], default)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\_td.py", line 2395, in _get_str
    return self._default_get(first_key, default)
  File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\base.py", line 4503, in _default_get
    raise KeyError(
KeyError: 'key "loss_critic" not found in TensorDict with keys [\'ESS\', \'clip_fraction\', \'entropy\', \'kl_approx\', \'loss_entropy\', \'loss_objective\']'

That clearly implies that the key "loss_critic" is missing from the TensorDict returned by the loss module (whereas before I connect the actor module to the collector, it is computed properly).
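
The same kind of check can be applied to the loss module itself; assuming critic_coef and entropy_coef are stored as buffers in this release, a critic_coef that has been reset to zero would explain why the critic term is skipped and "loss_critic" never appears in the output:

# Same check on the loss module; if critic_coef prints as tensor(0.), the
# critic loss is presumably skipped in forward, which matches the missing key.
for name, buf in self.loss_module.named_buffers():
    print(name, "->", buf)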
