PPO dual clip #37
base: master
Changes from all commits
@@ -10,28 +10,29 @@
 from cherry import debug


-def policy_loss(new_log_probs, old_log_probs, advantages, clip=0.1):
+def policy_loss(new_log_probs, old_log_probs, advantages, clip=0.1, dual_clip=None):
     """
     [[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/algorithms/ppo.py)

     **Description**

-    The clipped policy loss of Proximal Policy Optimization.
+    The dual-clipped policy loss of Dual-Clip Proximal Policy Optimization.

     **References**

-    1. Schulman et al. 2017. “Proximal Policy Optimization Algorithms.” arXiv [cs.LG].
+    1. Deheng Ye et al. 2020. “Mastering Complex Control in MOBA Games with Deep Reinforcement Learning.” arXiv:1912.09729.
Reviewer comment: Please keep the original reference -- you can add the new one as well.
     **Arguments**

     * **new_log_probs** (tensor) - The log-density of actions from the target policy.
     * **old_log_probs** (tensor) - The log-density of actions from the behaviour policy.
     * **advantages** (tensor) - Advantage of the actions.
     * **clip** (float, *optional*, default=0.1) - The clipping coefficient.
+    * **dual_clip** (float, *optional*, default=None) - The dual-clipping coefficient.

     **Returns**

-    * (tensor) - The clipped policy loss for the given arguments.
+    * (tensor) - The dual-clipped policy loss for the given arguments.

     **Example**
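For reference, the per-sample surrogate objective of the dual-clip formulation in Ye et al. (2020) can be sketched as follows, with $r_t(\theta)$ the probability ratio $\exp(\text{new\_log\_probs} - \text{old\_log\_probs})$, $\hat{A}_t$ the advantage, $\epsilon$ the `clip` coefficient, and $c$ the `dual_clip` coefficient:

$$
J_t(\theta) =
\begin{cases}
\max\Bigl(\min\bigl(r_t(\theta)\,\hat{A}_t,\; \operatorname{clip}\bigl(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\bigr)\,\hat{A}_t\bigr),\; c\,\hat{A}_t\Bigr), & \hat{A}_t < 0,\\
\min\bigl(r_t(\theta)\,\hat{A}_t,\; \operatorname{clip}\bigl(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\bigr)\,\hat{A}_t\bigr), & \hat{A}_t \ge 0.
\end{cases}
$$

The extra lower bound $c\,\hat{A}_t$ only applies to negative advantages, which is what the `[advantages < 0]` and `[advantages > 0]` masks in the change below implement.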
@@ -44,26 +45,38 @@ def policy_loss(new_log_probs, old_log_probs, advantages, clip=0.1):
                               next_state_value)
     new_densities = policy(replay.state())
     new_logprobs = new_densities.log_prob(replay.action())
-    loss = policy_loss(new_logprobs,
+    loss = loss_dual_clip(new_logprobs,
                        replay.logprob().detach(),
                        advantage.detach(),
-                       clip=0.2)
+                       clip=0.2,
+                       dual_clip=2)
     ~~~
     """
-    msg = 'new_log_probs, old_log_probs and advantages must have equal size.'
-    assert new_log_probs.size() == old_log_probs.size() == advantages.size(),\
-        msg
+    msg = "new_log_probs, old_log_probs and advantages must have equal size."
+    assert new_log_probs.size() == old_log_probs.size() == advantages.size(), msg
     if debug.IS_DEBUGGING:
         if old_log_probs.requires_grad:
-            debug.logger.warning('PPO:policy_loss: old_log_probs.requires_grad is True.')
+            debug.logger.warning(
+                "PPO:policy_loss: old_log_probs.requires_grad is True."
+            )
         if advantages.requires_grad:
-            debug.logger.warning('PPO:policy_loss: advantages.requires_grad is True.')
+            debug.logger.warning("PPO:policy_loss: advantages.requires_grad is True.")
         if not new_log_probs.requires_grad:
-            debug.logger.warning('PPO:policy_loss: new_log_probs.requires_grad is False.')
+            debug.logger.warning(
+                "PPO:policy_loss: new_log_probs.requires_grad is False."
+            )
Reviewer comment on lines -53 to +67: Those lines shouldn't be modified.
     ratios = th.exp(new_log_probs - old_log_probs)
     obj = ratios * advantages
     obj_clip = ratios.clamp(1.0 - clip, 1.0 + clip) * advantages
-    return - th.min(obj, obj_clip).mean()
+    if dual_clip is not None:
+        obj_dual_clip = dual_clip * advantages
+
+        return -(
+            (th.max(th.min(obj, obj_clip), obj_dual_clip)[advantages < 0]).mean()
+            + (th.min(obj, obj_clip)[advantages > 0]).mean()
+        )
+
+    return -th.min(obj, obj_clip).mean()


 def state_value_loss(new_values, old_values, rewards, clip=0.1):
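A minimal usage sketch of the proposed signature (assumed import path `cherry.algorithms.ppo`; the toy tensors and `dual_clip=2.0` are illustrative only):

~~~
import torch as th
from cherry.algorithms import ppo

# Toy batch: log-densities under the new and old policies, plus advantages.
new_log_probs = th.tensor([[-0.7], [-1.2], [-0.4]], requires_grad=True)
old_log_probs = th.tensor([[-0.9], [-0.7], [-0.5]])
advantages = th.tensor([[1.2], [-0.8], [0.3]])

# Standard clipped loss (dual_clip defaults to None).
standard = ppo.policy_loss(new_log_probs, old_log_probs, advantages, clip=0.2)

# Dual-clip variant: for negative advantages, the surrogate objective is
# additionally bounded below by dual_clip * advantage.
dual = ppo.policy_loss(new_log_probs, old_log_probs, advantages,
                       clip=0.2, dual_clip=2.0)
~~~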
@@ -99,16 +112,20 @@ def state_value_loss(new_values, old_values, rewards, clip=0.1):
                                  clip=0.2)
     ~~~
     """
-    msg = 'new_values, old_values, and rewards must have equal size.'
+    msg = "new_values, old_values, and rewards must have equal size."
     assert new_values.size() == old_values.size() == rewards.size(), msg
     if debug.IS_DEBUGGING:
         if old_values.requires_grad:
-            debug.logger.warning('PPO:state_value_loss: old_values.requires_grad is True.')
+            debug.logger.warning(
+                "PPO:state_value_loss: old_values.requires_grad is True."
+            )
         if rewards.requires_grad:
-            debug.logger.warning('PPO:state_value_loss: rewards.requires_grad is True.')
+            debug.logger.warning("PPO:state_value_loss: rewards.requires_grad is True.")
         if not new_values.requires_grad:
-            debug.logger.warning('PPO:state_value_loss: new_values.requires_grad is False.')
-    loss = (rewards - new_values)**2
+            debug.logger.warning(
+                "PPO:state_value_loss: new_values.requires_grad is False."
+            )
+    loss = (rewards - new_values) ** 2
     clipped_values = old_values + (new_values - old_values).clamp(-clip, clip)
-    clipped_loss = (rewards - clipped_values)**2
+    clipped_loss = (rewards - clipped_values) ** 2
Reviewer comment on lines +115 to +130: Those lines shouldn't be modified.
     return 0.5 * th.max(loss, clipped_loss).mean()
Reviewer comment: Please add to the original description, while keeping the original message.