[Feature] Log each entropy for composite distributions in PPO #2707
Conversation
torchrl/objectives/ppo.py (outdated)

```python
for head_key, head_entropy in entropy.items(
    include_nested=True, leaves_only=True
):
    td_out.set("-".join(head_key), head_entropy.detach().mean())
```
Choosing under which key to log each individual factor's entropy was a bit of a headache. The way I personally use `CompositeDistribution` yields tensordicts that look like:

```
action: {
    head_1: {
        action: ...
        entropy: ...
    },
    head_2: {
        action: ...
        entropy: ...
    }
}
```

which means that using `head_key[-1]` to log each entropy is not really a viable solution (all the factor entropies would be logged under the same name, `entropy`). I'm not sure how to get a one-size-fits-all solution here, and I'm happy to hear suggestions. The current solution ensures that there are no collisions, at the price of very verbose keys (see the sketch below).
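To make the collision concrete, here is a minimal sketch using the tensordict API (head names and shapes are made up for illustration):

```python
import torch
from tensordict import TensorDict

entropy = TensorDict(
    {
        "head_1": {"entropy": torch.rand(4)},
        "head_2": {"entropy": torch.rand(4)},
    },
    batch_size=[4],
)

# Using only the last element of each nested key collides: both heads map to "entropy".
leaf_names = {key[-1] for key in entropy.keys(include_nested=True, leaves_only=True)}
print(leaf_names)  # {'entropy'}

# Joining the full nested key keeps the entries distinct, at the price of verbose names.
for head_key, head_entropy in entropy.items(include_nested=True, leaves_only=True):
    print("-".join(head_key), head_entropy.detach().mean().item())
# head_1-entropy ...
# head_2-entropy ...
```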
Wouldn't the most generic solution just be to log the entropy TD as it comes? Why do we need to rename it?
BTW, it seems to me that what you're doing here amounts to `tensordict.flatten_keys("-").detach().mean()`.
Nit: this isn't collision-safe I think (but `flatten_keys` will tell you if there are any collisions): e.g. `("key-one", "entropy")` and `("key", "one", "entropy")` will collide.
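For illustration, a small sketch of the `flatten_keys` one-liner and of the collision case (key names are made up):

```python
import torch
from tensordict import TensorDict

entropy = TensorDict(
    {
        "head_1": {"entropy": torch.rand(4)},
        "head_2": {"entropy": torch.rand(4)},
    },
    batch_size=[4],
)

# One-liner equivalent of the loop above: flatten nested keys with "-" as separator.
flat = entropy.flatten_keys("-").detach()
print(sorted(flat.keys()))  # ['head_1-entropy', 'head_2-entropy']

# Collision case: both keys below flatten to "key-one-entropy",
# so flatten_keys complains instead of silently overwriting.
colliding = TensorDict(
    {
        "key-one": {"entropy": torch.rand(4)},
        "key": {"one": {"entropy": torch.rand(4)}},
    },
    batch_size=[4],
)
try:
    colliding.flatten_keys("-")
except Exception as exc:
    print("collision:", exc)
```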
Also: are we 100% sure all keys are nested? I think so (that's how `CompositeDist` works), but maybe we could add a safeguard check here so that an error is raised if that assumption is violated (e.g. users have their own dist class that returns `{"entropy", ("nested", "entropy")}` keys).
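One possible shape for such a safeguard, sketched against a hypothetical tensordict that mixes flat and nested keys (illustrative only, not the actual implementation):

```python
import torch
from tensordict import TensorDict

# A tensordict that violates the "all entropy keys are nested" assumption.
entropy = TensorDict(
    {"entropy": torch.rand(4), "nested": {"entropy": torch.rand(4)}},
    batch_size=[4],
)

for head_key in entropy.keys(include_nested=True, leaves_only=True):
    # Nested keys come back as tuples; a plain string means a flat key.
    if not isinstance(head_key, tuple):
        raise ValueError(
            "Expected nested entropy keys (as produced by CompositeDistribution), "
            f"but got the flat key {head_key!r}."
        )
```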
I followed your recommendation and added a `composite_entropy` key to the loss td. Two remarks:
- The composite entropy is not logged under `entropy`, to avoid breaking BC (users currently expect a Tensor).
- I did not `detach()` the composite entropy; this would allow the user to compute a custom entropy bonus when using a composite entropy (e.g. a different penalty per head, see the sketch after this comment).

Wdyt?
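For context, a hypothetical sketch of what a non-detached composite entropy would enable, namely a custom per-head entropy bonus (head names and coefficients are made up; note that the merged version ends up detaching, per the discussion below):

```python
import torch
from tensordict import TensorDict

# Pretend this is the (non-detached) composite entropy coming out of the loss.
composite_entropy = TensorDict(
    {
        "head_1": {"entropy": torch.rand(4, requires_grad=True)},
        "head_2": {"entropy": torch.rand(4, requires_grad=True)},
    },
    batch_size=[4],
)

# A different entropy coefficient per head.
coefs = {("head_1", "entropy"): 0.01, ("head_2", "entropy"): 0.05}
entropy_bonus = sum(
    coefs[key] * value.mean()
    for key, value in composite_entropy.items(include_nested=True, leaves_only=True)
)
# entropy_bonus is differentiable and could be subtracted from the loss before backward().
```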
Not detaching could be slightly bc-breaking: what if I do `loss_tensordict.sum(reduce=True).backward()`?
Previously this gave the right result; now it would also backprop through the entropy. Usually, metadata in loss outputs is guaranteed to be non-differentiable, so that would be a one-off.
But I understand that it could be useful...
We could add a kwarg in the constructor (which would become a bit overloaded!)
Right. I'll detach for now, let's revisit when there is a need/ask for it?
Yes, thanks for raising it!
Thinking about it, a way of doing this could be to pass the entropy coefficient as a tensordict and do `(td_coef * td_entropy).sum(reduce=True)`.
Idea for a follow-up PR ;)
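A rough sketch of that follow-up idea, assuming a tensordict version that supports elementwise arithmetic between tensordicts and `sum(reduce=True)` (head names and coefficient values are made up):

```python
import torch
from tensordict import TensorDict

td_entropy = TensorDict(
    {
        "head_1": {"entropy": torch.rand(4)},
        "head_2": {"entropy": torch.rand(4)},
    },
    batch_size=[4],
)
# Per-head coefficients mirroring the entropy structure.
td_coef = TensorDict(
    {
        "head_1": {"entropy": torch.full((4,), 0.01)},
        "head_2": {"entropy": torch.full((4,), 0.05)},
    },
    batch_size=[4],
)

# Elementwise product, then a reduction over all entries to a single scalar.
entropy_bonus = (td_coef * td_entropy).sum(reduce=True)
print(entropy_bonus)
```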
Force-pushed from 71bd4bd to 06c3d94.
LGTM, just a couple of comments to address before we merge it.
LMK what you think would be the best way to log the entropies in the output data structure; I think flattening may be a bit surprising.
LGTM thanks!
Description

This PR enables PPO to log the entropy of each individual head of a composite policy separately. Concretely, for a composite distribution with, say, a nested discrete and a continuous head, the `td_out` is augmented with detached per-head entropy values.

Motivation and Context

This is an extremely useful debugging tool when training composite policies.
Types of changes

What types of changes does your code introduce? Remove all that do not apply:

Checklist

Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!