Commit

Change head entropy logged name
Louis Faury committed Jan 20, 2025
1 parent 2a68c60 commit 71bd4bd
Showing 2 changed files with 7 additions and 7 deletions.
5 changes: 4 additions & 1 deletion test/test_cost.py
@@ -8409,7 +8409,6 @@ def test_ppo_composite_no_aggregate(
         if isinstance(loss_fn, KLPENPPOLoss):
             kl = loss.pop("kl_approx")
             assert (kl != 0).any()
-
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -8637,12 +8636,16 @@ def test_ppo_shared_seq(
         )
 
         loss = loss_fn(td).exclude("entropy")
+        if composite_action_dist:
+            loss = loss.exclude("action-action1_entropy")
 
         sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
         grad = TensorDict(dict(model.named_parameters()), []).apply(
             lambda x: x.grad.clone()
         )
         loss2 = loss_fn2(td).exclude("entropy")
+        if composite_action_dist:
+            loss2 = loss2.exclude("action-action1_entropy")
 
         model.zero_grad()
         sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
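Both new exclusions strip the per-head entry this commit adds under the flattened name "action-action1_entropy": like the aggregate "entropy" entry, it is a detached logging value rather than a loss term. A minimal sketch of the pattern, assuming a loss TensorDict shaped roughly like the ones these tests build (keys and values below are illustrative only, not the tests' actual output):

    import torch
    from tensordict import TensorDict

    loss = TensorDict(
        {
            "loss_objective": torch.tensor(1.0, requires_grad=True),
            "entropy": torch.tensor(0.5),
            # per-head logging entry introduced by this commit
            "action-action1_entropy": torch.tensor(0.25),
        },
        [],
    )
    # Drop the logging-only entries, then backprop through the loss terms only.
    loss = loss.exclude("entropy", "action-action1_entropy")
    sum(val for key, val in loss.items() if key.startswith("loss_")).backward()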
9 changes: 3 additions & 6 deletions torchrl/objectives/ppo.py
@@ -687,8 +687,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             for head_key, head_entropy in entropy.items(
                 include_nested=True, leaves_only=True
             ):
-                head_prefix = head_key[0] if len(head_key) == 1 else head_key[-2]
-                td_out.set(f"{head_prefix}_entropy", head_entropy.detach().mean())
+                td_out.set("-".join(head_key), head_entropy.detach().mean())
             entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
@@ -991,8 +990,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             for head_key, head_entropy in entropy.items(
                 include_nested=True, leaves_only=True
             ):
-                head_prefix = head_key[0] if len(head_key) == 1 else head_key[-2]
-                td_out.set(f"{head_prefix}_entropy", head_entropy.detach().mean())
+                td_out.set("-".join(head_key), head_entropy.detach().mean())
             entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
@@ -1308,8 +1306,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             for head_key, head_entropy in entropy.items(
                 include_nested=True, leaves_only=True
             ):
-                head_prefix = head_key[0] if len(head_key) == 1 else head_key[-2]
-                td_out.set(f"{head_prefix}_entropy", head_entropy.detach().mean())
+                td_out.set("-".join(head_key), head_entropy.detach().mean())
             entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
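All three hunks above make the same substitution: the old code collapsed a nested head key to the level above its leaf, so two entropy heads sharing a parent (e.g. ("action", "a_entropy") and ("action", "b_entropy")) were both logged as "action_entropy"; the new code flattens the full key with "-". A standalone sketch of the before/after naming, assuming a head key shaped like the ("action", "action1_entropy") leaf implied by the test exclusions above:

    # Hypothetical nested head key, as yielded by
    # entropy.items(include_nested=True, leaves_only=True).
    head_key = ("action", "action1_entropy")

    # Old naming: only the level above the leaf survives,
    # so sibling heads collide on the same logged name.
    head_prefix = head_key[0] if len(head_key) == 1 else head_key[-2]
    old_name = f"{head_prefix}_entropy"  # -> "action_entropy"

    # New naming: the full nested key, flattened with "-".
    new_name = "-".join(head_key)  # -> "action-action1_entropy"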
