-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Recurrent PPO #53
Recurrent PPO #53
Conversation
Actually not, see the shape in the buffer https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/feat/ppo-lstm/sb3_contrib/common/recurrent/buffers.py#L164 |
for the time being, is there any way we can feedback the older data back to the trainer as input to mimic a crude version of LSTM? any sample code to do that? |
Hello, |
when LSTM for A2C? |
a2c is a special case of ppo ;) (cc @vwxyzjn ) |
@henrydeclety see https://github.com/vwxyzjn/a2c_is_a_special_case_of_ppo. We have a paper coming out soon... |
The preprint of the paper is out at https://arxiv.org/abs/2205.09123 @henrydeclety :) |
I’ll give it a try |
How could I configure the maximum sequence length for the LSTM? |
@EloyAnguiano As far as I could tell from the code, the implementation in SB3 does not have a sequence length, but saves the hidden state between steps of your environment and then uses it as input. So the maximum sequence length for the lstm would be the number of steps (n_steps) before you update your policy. This way you only need to compute each input once, instead of refeeding it every new step. |
Description
Experimental version of PPO with LSTM policy.
Current status: usable but not polished, see #53 (comment)
Missing:
Known issue: if the model was train on GPU and tested on CPU, a warning will be issued because it cannot unpickle the lstm initial states. This is ok as they will be reset anyway in
setup_model()
and it doesn't affect prediction.Context
closes recurrent policy implementation in ppo [feature-request] DLR-RM/stable-baselines3#18
Types of changes
Checklist:
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)Note: we are using a maximum length of 127 characters per line