[Feature] Log each entropy for composite distributions in PPO (#2707)
Co-authored-by: Louis Faury <[email protected]>
louisfaury and Louis Faury authored Jan 24, 2025
1 parent d4e4019 commit 319bb68
Showing 2 changed files with 26 additions and 13 deletions.
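With this change, PPO loss modules built on a composite action distribution expose a new "composite_entropy" entry in their output TensorDict, next to the aggregated "entropy" scalar. A minimal sketch of how the new key might be consumed, assuming `loss_td` is the TensorDict returned by a `ClipPPOLoss`/`KLPENPPOLoss` call; the helper name `log_ppo_entropies` is illustrative, not TorchRL API:

from tensordict import TensorDict

def log_ppo_entropies(loss_td: TensorDict) -> None:
    # Aggregated entropy (unchanged behaviour): a single scalar used for logging.
    print("entropy:", loss_td["entropy"].item())
    # New with this commit: one entropy value per action head of the
    # composite distribution, reported under "composite_entropy".
    if "composite_entropy" in loss_td.keys():
        for head, head_entropy in loss_td["composite_entropy"].items(
            include_nested=True, leaves_only=True
        ):
            print(f"entropy[{head}]:", head_entropy.mean().item())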
5 changes: 4 additions & 1 deletion test/test_cost.py
@@ -8409,7 +8409,6 @@ def test_ppo_composite_no_aggregate(
if isinstance(loss_fn, KLPENPPOLoss):
kl = loss.pop("kl_approx")
assert (kl != 0).any()

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
@@ -8637,12 +8636,16 @@ def test_ppo_shared_seq(
)

loss = loss_fn(td).exclude("entropy")
if composite_action_dist:
loss = loss.exclude("composite_entropy")

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")
if composite_action_dist:
loss2 = loss2.exclude("composite_entropy")

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
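For context on the test change above: `TensorDict.exclude` returns a copy without the named entries, which is how the diagnostic-only keys ("entropy", and now "composite_entropy" when the action distribution is composite) are dropped before the "loss_*" terms are summed and back-propagated. A small self-contained illustration with made-up loss values:

import torch
from tensordict import TensorDict

loss = TensorDict(
    {
        "loss_objective": torch.tensor(1.0),
        "loss_entropy": torch.tensor(-0.1),
        "entropy": torch.tensor(0.5),
        "composite_entropy": TensorDict(
            {"head0": torch.tensor(0.2), "head1": torch.tensor(0.3)}, batch_size=[]
        ),
    },
    batch_size=[],
)
# Drop the diagnostic entries, then sum only the terms that should
# contribute to the gradient, as in the test above.
loss = loss.exclude("entropy", "composite_entropy")
total = sum(val for key, val in loss.items() if key.startswith("loss_"))
assert set(loss.keys()) == {"loss_objective", "loss_entropy"}
print(total)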
34 changes: 22 additions & 12 deletions torchrl/objectives/ppo.py
@@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
def reset(self) -> None:
pass

def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
try:
entropy = dist.entropy()
except NotImplementedError:
@@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)

entropy = -log_prob.mean(0)
if is_tensor_collection(entropy):
entropy = _sum_td_features(entropy)
return entropy.unsqueeze(-1)

def _log_weight(
self, tensordict: TensorDictBase
) -> Tuple[torch.Tensor, d.Distribution]:
) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:

with self.actor_network_params.to_module(
self.actor_network
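The renamed `_get_entropy` keeps the existing fallback: when a distribution has no analytic `entropy()`, the entropy is estimated from sampled log-probabilities; for composite distributions the per-head result is now returned as a TensorDict instead of being summed here. A standalone sketch of the scalar fallback only (the sample count `n` and the name `mc_entropy` are illustrative, not TorchRL API):

import torch
from torch import distributions as d

def mc_entropy(dist: d.Distribution, n: int = 1000) -> torch.Tensor:
    # Use the analytic entropy when the distribution provides it...
    try:
        return dist.entropy()
    except NotImplementedError:
        # ...otherwise a Monte Carlo estimate: the negative mean
        # log-probability of samples drawn from the distribution itself.
        x = dist.rsample((n,)) if dist.has_rsample else dist.sample((n,))
        return -dist.log_prob(x).mean(0)

print(mc_entropy(d.Normal(0.0, 1.0)))       # analytic path
print(mc_entropy(d.Kumaraswamy(2.0, 2.0)))  # typically hits the MC fallback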
@@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
log_weight = log_weight.view(advantage.shape)
neg_loss = log_weight.exp() * advantage
td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
entropy = self._get_entropy(dist)
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
@@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# ESS for logging
with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
# to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
# of the weights.
# to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
# dispersion.
lw = log_weight.squeeze()
if not isinstance(lw, torch.Tensor):
lw = _sum_td_features(lw)
@@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
gain = _sum_td_features(gain)
td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
td_out.set("clip_fraction", clip_fraction)
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
entropy = self._get_entropy(dist)
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
@@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
{
"loss_objective": -neg_loss,
"kl": kl.detach(),
"kl_approx": kl_approx.detach().mean(),
},
batch_size=[],
)

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
entropy = self._get_entropy(dist)
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
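The three forward() methods above now share one pattern: when `_get_entropy` returns a tensor collection, each head's entropy is logged under "composite_entropy" and the entropy bonus is computed on the sum across heads (the role `_sum_td_features` plays in TorchRL). A standalone mimic with illustrative head names and coefficient:

import torch
from tensordict import TensorDict
from torch import distributions as d

entropy_coef = 0.01  # illustrative value

# Per-head entropies, as _get_entropy would return them for a composite
# distribution over two action heads (the key names are made up).
per_head_entropy = TensorDict(
    {
        "agent0_action": d.Normal(torch.zeros(4), torch.ones(4)).entropy(),
        "agent1_action": d.Categorical(logits=torch.zeros(4, 3)).entropy(),
    },
    batch_size=[4],
)

td_out = TensorDict({}, batch_size=[])
# Report each head's entropy for logging, as the new "composite_entropy" key.
td_out.set("composite_entropy", per_head_entropy.detach())
# Aggregate across heads before computing the scalar log value and the
# bonus term, mirroring _sum_td_features.
entropy = sum(per_head_entropy.values(True, True))
td_out.set("entropy", entropy.detach().mean())
td_out.set("loss_entropy", -entropy_coef * entropy)
print(td_out)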
