Prioritized experience replay #1622
base: master
How could we add proper support for multiple envs? Any ideas? Could the random line below work?
Not sure yet; the random line below might work, but we need to check that it doesn't affect performance first.
araffin/sbx@b5ce091 should be better, see araffin/sbx#50
@AlexPasqua Ideally, we'd like to be able to associate it with all off-policy algos without adaptation, but I don't see a simple way of doing that at this stage.
Also related, we had discussed not modifying DQN: Stable-Baselines-Team/stable-baselines3-contrib#127 (comment)
I'm interested in this PR. Since every algo-specific `train` method includes a `replay_buffer.sample` line, couldn't we just additionally add a `replay_buffer.update` line? The update function could take in the current and target Q-values whenever a value function is present, or maybe even all the local variables. It would do nothing for the vanilla replay buffer. Would this be an acceptable modification?
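To illustrate the idea, here is a rough sketch only; the class and method names are hypothetical, not SB3's actual API. The base buffer gets a no-op `update` that a prioritized buffer overrides, and each `train` method calls it right after computing the TD errors.

```python
# Sketch of the proposed hook (illustrative, not SB3's actual classes):
# a no-op `update` on the base buffer, overridden by a prioritized buffer.
import numpy as np
import torch as th


class ReplayBuffer:
    def update(self, batch_indices: np.ndarray, td_errors: th.Tensor) -> None:
        # Vanilla buffer: nothing to do.
        pass


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, buffer_size: int, alpha: float = 0.6, eps: float = 1e-6):
        self.priorities = np.full(buffer_size, 1.0, dtype=np.float64)
        self.alpha = alpha
        self.eps = eps

    def update(self, batch_indices: np.ndarray, td_errors: th.Tensor) -> None:
        # Proportional PER: new priority ~ (|TD error| + eps) ** alpha
        new_priorities = (np.abs(td_errors.detach().cpu().numpy()) + self.eps) ** self.alpha
        self.priorities[batch_indices] = new_priorities


# Hypothetical usage inside an algorithm's `train()` loop:
# replay_data, batch_indices = replay_buffer.sample(batch_size)
# td_errors = current_q_values - target_q_values
# replay_buffer.update(batch_indices, td_errors)
```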
Thanks for your comment!
How do you handle the loss in your proposal?
If we want this to work for general off-policy algorithms, we could update the `ReplayBufferSamples`-like classes to additionally include an `importance_sampling_weight` attribute, which would be updated from the `replay_buffer.update` method. Then I see two ways to handle the loss under this interface. The obvious downside of the first is that it requires hand engineering for the different types of loss functions or priority metrics. The second is to make the `train` methods "td-error" centric, in the sense that we always compute `td_error = importance_sampling_weight * th.abs(current_q_values - target_q_values)` first, then the loss `loss = loss_fn(td_error)`. The downside of this approach is that we can't use the PyTorch API for computing the loss, and would have to write those functions ourselves.

Either approach requires computing a `td_error` variable, which unfortunately requires somewhat intrusive code changes. What do you think?
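For concreteness, a possible sketch of the "td-error centric" option (illustrative names and signatures, not code from this PR, and a slight variation in which the importance-sampling weights multiply the per-sample loss rather than the absolute error): compute the TD error first, then an elementwise Huber loss with `reduction="none"` so the weights can be applied before reducing.

```python
# Sketch of a weighted TD loss for a prioritized buffer (assumed interface):
# the sampled batch carries per-sample importance-sampling weights.
import torch as th
import torch.nn.functional as F


def weighted_td_loss(
    current_q_values: th.Tensor,             # shape (batch, 1)
    target_q_values: th.Tensor,              # shape (batch, 1)
    importance_sampling_weights: th.Tensor,  # shape (batch, 1), from the buffer
) -> tuple[th.Tensor, th.Tensor]:
    td_error = current_q_values - target_q_values
    # reduction="none" keeps per-sample losses so they can be re-weighted,
    # which is why the usual one-call PyTorch losses are not directly usable.
    elementwise_loss = F.smooth_l1_loss(current_q_values, target_q_values, reduction="none")
    loss = (importance_sampling_weights * elementwise_loss).mean()
    # |td_error| is what the prioritized buffer would use to update priorities.
    return loss, td_error.abs().detach()
```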
Maybe to make things clearer: my plan is not to have PER for all algorithms, mainly for two reasons. If users really want PER in another algo, they can take inspiration from a reference implementation in SB3 and integrate it themselves (the same way we don't provide maskable + recurrent PPO at the same time).
"just" yes, I would be happy to receive such PR =)
the main thing is to benchmark the implementation and reproduce the published results.
This PR is also still open because I was not satisfied by the result of DQN + PER (I couldn't see significant different with respect to DQN).
One thing I had in mind was to implement CNN support for SBX (https://github.com/araffin/sbx) in order to iterate faster and check the PER results, but I haven't had time to do so until now...
Why don't we implement the toy environment from figure 1 of https://arxiv.org/pdf/1511.05952 as the PER benchmark? It would be a simpler initial correctness check than the Atari environments.
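For reference, a rough sketch of how that toy domain ("Blind Cliffwalk", figure 1 of the paper) could look as a Gymnasium env, based on my reading of the figure; this is not code from the PR. There are n states and two actions: the "wrong" action ends the episode with reward 0, the "right" action moves one state forward, and only the final "right" transition gives a reward of 1.

```python
# Sketch of a "Blind Cliffwalk"-style toy environment (illustrative only).
import gymnasium as gym
from gymnasium import spaces


class BlindCliffwalkEnv(gym.Env):
    def __init__(self, n_states: int = 16):
        super().__init__()
        self.n_states = n_states
        self.observation_space = spaces.Discrete(n_states)
        self.action_space = spaces.Discrete(2)
        self.state = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.state = 0
        return self.state, {}

    def step(self, action):
        # Action 0 is "right", action 1 is "wrong" (arbitrary choice for this sketch).
        if action == 1:
            return self.state, 0.0, True, False, {}
        self.state += 1
        terminated = self.state >= self.n_states - 1
        reward = 1.0 if terminated else 0.0
        return self.state, reward, terminated, False, {}
```

Under uniform random exploration the reward is reached with probability roughly 2^-n, which should be exactly the regime where prioritized replay visibly helps.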
The toy environment can be a start for fast iteration and debugging, but what we learned in the past is that subtle bugs only show up on more complex tasks (see #48 and #47, where we found bugs such as the PyTorch and TF RMSProp implementations not being the same).
I see, will definitely work towards it!