Description
Environment
OS: Windows 11
Python: CPython 3.10.14
TorchRL Version: 0.5.0
PyTorch Version: 2.4.1+cu124
Gym Environment: A custom subclass of EnvBase (from torchrl.envs)
The project I'm working on is relatively complex, so I only include the parts of the code that I know are related to the bug described below.
Here are the definitions of the actor, value (critic), advantage, and loss modules:
import torch
from tensordict.nn import TensorDictModule
from torchrl.data import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator, OneHotCategorical
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

self.action_spec = DiscreteTensorSpec(3, dtype=torch.int8)

# Actor
_actor = TensorDictModule(self.agent, in_keys=self.agent.in_keys,
                          out_keys=self.agent.out_keys).to(self.agent.device)
self.actor_module = ProbabilisticActor(
    _actor,
    in_keys=self.agent.out_keys,
    spec=self.action_spec,
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)

# Critic
self.value_net = MyValueNetwork(device=agent.device)
self.value_module = ValueOperator(
    self.value_net, in_keys=self.value_net.in_keys,
)

# Advantage
self.advantage_module = GAE(
    value_network=self.value_module,
    gamma=self.advantage_gamma,
    lmbda=self.advantage_lmbda,
    differentiable=True,
    average_gae=True,
    device=self.value_module.device,
)

# Loss
entropy_eps = 0.001
self.loss_module = ClipPPOLoss(
    actor_network=self.actor_module,
    critic_network=self.value_module,
    clip_epsilon=(0.2, ),
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)
Training loop
My training loop takes the batched data from a MultiSyncDataCollector, adds it to a replay buffer backed by a LazyTensorStorage, then samples from the buffer and passes the sample to the _optimize_sample function below (a rough sketch of the collector/buffer wiring follows the function):
def _optimize_sample(self, sample: TensorDict):
    self.actor_optimizer.zero_grad()
    self.critic_optimizer.zero_grad()
    # Lazily initialize the critic weights on the first sample.
    if not self.value_net.weights_initialized:
        self.value_net(sample["observation"])
        self.value_net.weights_initialized = True
    self.actor_module(sample)
    sample["sample_log_prob"] = sample["sample_log_prob"].detach()
    self.advantage_module(sample)
    loss_vals = self.loss_module(sample)
    total_loss = loss_vals["loss_entropy"] + \
        loss_vals["loss_objective"] + loss_vals["loss_critic"]
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(
        self.actor_module.parameters(), max_norm=0.1)
    torch.nn.utils.clip_grad_norm_(
        self.value_module.parameters(), max_norm=1.0)
    self.actor_optimizer.step()
    self.critic_optimizer.step()
    return loss_vals
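For context, the collection side is wired roughly like this (a simplified sketch; the environment constructors, frame counts, and sample size are placeholders, not my exact values):
from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

collector = MultiSyncDataCollector(
    [make_env, make_env],        # placeholder env constructors
    policy=None,                 # random actions until the actor is plugged in (see below)
    frames_per_batch=1000,       # placeholder
    total_frames=-1,             # placeholder: collect indefinitely
)
replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100_000))
for batch in collector:
    replay_buffer.extend(batch.reshape(-1))  # flatten batch/time dims before storing
    sample = replay_buffer.sample(256)       # placeholder sample size
    loss_vals = self._optimize_sample(sample)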
With the code above (without the detach line yet), I got the error below; it is raised inside self.loss_module(sample) because self.actor_module(sample) writes sample_log_prob back with gradients attached:
loss_vals = self.loss_module(sample)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\common.py", line 39, in new_forward
return func(self, *args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\ppo.py", line 817, in forward
log_weight, dist, kl_approx = self._log_weight(tensordict)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\ppo.py", line 473, in _log_weight
raise RuntimeError("tensordict prev_log_prob requires grad.")
RuntimeError: tensordict prev_log_prob requires grad.
So I added sample["sample_log_prob"] = sample["sample_log_prob"].detach() to detach sample_log_prob from the computation graph, and the issue was resolved.
At this stage the model seems to converge, as both the objective and the critic loss are decreasing:
Figure 1 - Objective/Policy loss (exponential moving average, interval 100)
The main issue
Up to this point everything appears to work, but the main issue occurs when I connect the actor (policy) module to the collector, so that data is collected with the current policy rather than random actions:
train_kwargs["policy_device"] = self.agent.device
train_kwargs["policy"] = self.actor_module
my_collector = MultiSyncDataCollector(**train_kwargs)
And when I run it, I get the error below (thrown inside self.advantage_module(sample)):
self.advantage_module(sample)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 68, in new_func
return fun(self, *args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 57, in new_fun
return fun(self, *args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\nn\common.py", line 297, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\advantages.py", line 1357, in forward
adv, value_target = vec_generalized_advantage_estimate(
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 89, in transposed_fun
out = fun(*args, **kwargs)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 315, in vec_generalized_advantage_estimate
return _fast_vec_gae(
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\functional.py", line 250, in _fast_vec_gae
advantage = _custom_conv1d(td0_flat.unsqueeze(1), gammalmbdas)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\torchrl\objectives\value\utils.py", line 78, in _custom_conv1d
val_pad = torch.nn.functional.pad(tensor, [0, filter.shape[-2] - 1])
IndexError: tuple index out of range
Digging into torchrl\objectives\value\functional.py, inside the function vec_generalized_advantage_estimate (around line 307), I found that the value variable is a 1-D vector of zeros with the length of the sample batch size, whereas without connecting the actor_module it is the correct matrix of multiplied gammas and lambdas (with one column). I also found that when the collector uses the actor module, the gamma and lmbda buffers of the advantage module are reset to 0.0 (inside the training loop, print("Gamma : ", self.advantage_module.get_buffer("gamma")) outputs tensor(0.)).
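Here is the check I ran inside the training loop (purely diagnostic):
# Compare the GAE buffers with the values I passed at construction time.
print("Gamma buffer:", self.advantage_module.get_buffer("gamma"),
      "expected:", self.advantage_gamma)
print("Lmbda buffer:", self.advantage_module.get_buffer("lmbda"),
      "expected:", self.advantage_lmbda)
# As soon as the collector is built with policy=self.actor_module,
# both buffers print tensor(0.).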
So I added these two lines after the loss module definition:
self.advantage_module.register_buffer("gamma", torch.tensor(self.advantage_gamma))
self.advantage_module.register_buffer("lmbda", torch.tensor(self.advantage_lmbda))
By adding these two lines of code, the previous error vanished, but a new issue appeared:
loss_vals["loss_objective"] + loss_vals["loss_critic"]
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\base.py", line 335, in __getitem__
result = self._get_tuple(idx_unravel, NO_DEFAULT)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\_td.py", line 2399, in _get_tuple
first = self._get_str(key[0], default)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\_td.py", line 2395, in _get_str
return self._default_get(first_key, default)
File "C:\Users\ARX\anaconda3\envs\agent\lib\site-packages\tensordict\base.py", line 4503, in _default_get
raise KeyError(
KeyError: 'key "loss_critic" not found in TensorDict with keys [\'ESS\', \'clip_fraction\', \'entropy\', \'kl_approx\', \'loss_entropy\', \'loss_objective\']'
That clearly implies that the key "loss_critic" is missing from the TensorDict returned by the loss module (whereas before I connect the actor module to the collector, it is computed properly).
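For completeness, printing the keys of the loss output right after the call makes this easy to see (diagnostic only):
loss_vals = self.loss_module(sample)
print(sorted(loss_vals.keys()))
# With the actor connected to the collector, 'loss_critic' is missing from this list
# (matching the KeyError above); without the actor connected, it is present.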