
Are you interested in PRs for improvements in performance of PPO LSTM script? #276

Open
thomasbbrunner opened this issue Sep 19, 2022 · 3 comments


@thomasbbrunner

Problem Description

The current PPO with LSTM script ppo_atari_lstm.py steps through the LSTM sequentially, i.e. each step of the sequence is processed individually:

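```python
new_hidden = []
for h, d in zip(hidden, done):
    # Zero the recurrent state of every env whose previous episode just
    # ended (d == 1) before feeding in this step.
    h, lstm_state = self.lstm(
        h.unsqueeze(0),
        (
            (1.0 - d).view(1, -1, 1) * lstm_state[0],
            (1.0 - d).view(1, -1, 1) * lstm_state[1],
        ),
    )
    new_hidden += [h]
```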

This is much slower than passing the entire sequence to the LSTM in a single call:

```python
h, lstm_state = self.lstm(hidden, lstm_state)
```

However, this usually cannot be done in RL, because the hidden state has to be reset whenever an episode ends, and an episode can end partway through the sequence.
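To see why, here is a small sketch comparing the masked loop with a single unmasked call. It assumes the `(seq_len, batch, features)` layout and the convention from the loop above that `d == 1` marks the first step of a new episode; all sizes are arbitrary. The outputs agree up to the episode boundary and then diverge, because the single call carries the old episode's state into the new one.

```python
import torch

torch.manual_seed(0)
T, B, F, H = 6, 2, 3, 4  # steps, envs, feature size, LSTM hidden size
lstm = torch.nn.LSTM(F, H)  # expects (seq_len, batch, features)
hidden = torch.randn(T, B, F)
done = torch.zeros(T, B)
done[4, 0] = 1.0  # env 0: a new episode starts at step 4
state = (torch.zeros(1, B, H), torch.zeros(1, B, H))

with torch.no_grad():
    # Step-by-step loop that zeroes the state where a new episode starts.
    outs, s = [], state
    for h, d in zip(hidden, done):
        mask = (1.0 - d).view(1, -1, 1)
        h, s = lstm(h.unsqueeze(0), (mask * s[0], mask * s[1]))
        outs.append(h)
    ref = torch.cat(outs)
    # Single call over the whole sequence: the state is never reset.
    naive, _ = lstm(hidden, state)

print(torch.allclose(ref[:4], naive[:4], atol=1e-5))  # True: identical before the reset
print(torch.allclose(ref[4:, 0], naive[4:, 0], atol=1e-5))  # False: env 0 diverges after its boundary
```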

Other PPO implementations use a trick: split a sequence that contains several trajectories into several sequences that each contain only one trajectory. This is done by splitting the input sequence at every done and padding the resulting sequences to a common length. This can be visualized as:

```
Original sequences: [ [a1, a2, a3, a4 | a5, a6],
                      [b1, b2 | b3, b4, b5 | b6] ]

Split sequences:    [ [a1, a2, a3, a4],
                      [a5, a6,  0,  0],
                      [b1, b2,  0,  0],
                      [b3, b4, b5,  0],
                      [b6,  0,  0,  0] ]
```
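The splitting step can be sketched in plain Python. `split_and_pad` is a hypothetical helper, not code from the proposed script, and it assumes the same convention as the loop above: a done flag of 1 at step t means step t starts a new episode, so the sequence is cut just before it. Segments are right-padded with 0 to the longest length, and the true lengths are returned too, since a batched LSTM call needs them to tell real steps from padding.

```python
from typing import List, Sequence, Tuple, Union

Item = Union[str, int]


def split_and_pad(
    sequences: Sequence[Sequence[str]],
    dones: Sequence[Sequence[int]],
    pad: int = 0,
) -> Tuple[List[List[Item]], List[int]]:
    """Split each sequence before every step whose done flag is 1, then right-pad.

    Assumes done[t] == 1 means step t is the first step of a new episode.
    """
    segments: List[List[Item]] = []
    for seq, done in zip(sequences, dones):
        current: List[Item] = []
        for x, d in zip(seq, done):
            if d and current:  # episode boundary: close the running segment
                segments.append(current)
                current = []
            current.append(x)
        if current:
            segments.append(current)
    lengths = [len(s) for s in segments]
    max_len = max(lengths)
    padded = [s + [pad] * (max_len - len(s)) for s in segments]
    return padded, lengths


padded, lengths = split_and_pad(
    [["a1", "a2", "a3", "a4", "a5", "a6"], ["b1", "b2", "b3", "b4", "b5", "b6"]],
    [[0, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 1]],
)
for row in padded:
    print(row)
print(lengths)  # [4, 2, 2, 3, 1]
```

The printed rows reproduce the split sequences in the diagram.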

With this trick it is possible to make a single call to the LSTM to process multiple sequences and batches.
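A minimal PyTorch sketch of the single-call version follows. It is not the author's implementation: `lstm_split` and `lstm_sequential` are hypothetical names, and instead of running the zero padding through the LSTM it packs the padded batch with `pack_padded_sequence` (a swapped-in choice), so each segment's final state is taken at its last real step. Each environment's first segment starts from that environment's stored state; later segments start from zeros, mirroring the reset in the loop. The example checks the result against the step-by-step loop on the a/b sequences from the diagram.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence


def lstm_sequential(lstm, hidden, done, state):
    """Reference: the step-by-step loop, zeroing the state where d == 1."""
    outs = []
    for h, d in zip(hidden, done):
        mask = (1.0 - d).view(1, -1, 1)
        h, state = lstm(h.unsqueeze(0), (mask * state[0], mask * state[1]))
        outs.append(h)
    return torch.cat(outs), state


def lstm_split(lstm, hidden, done, state):
    """One LSTM call over per-episode segments (hypothetical helper).

    hidden: (T, B, F) features, done: (T, B) flags, state: ((1, B, H), (1, B, H)).
    """
    T, B, _ = hidden.shape
    segs, h0, c0, index = [], [], [], []  # index holds (env, start, length)
    for e in range(B):
        # A segment starts at step 0 and at every later step that begins a new episode.
        starts = [0] + [t for t in range(1, T) if done[t, e].item() == 1]
        for s, end in zip(starts, starts[1:] + [T]):
            segs.append(hidden[s:end, e])
            index.append((e, s, end - s))
            # Only env e's first segment continues from the stored state,
            # and only if step 0 is not itself an episode start.
            keep = 1.0 if s == 0 and done[0, e].item() == 0 else 0.0
            h0.append(keep * state[0][:, e])
            c0.append(keep * state[1][:, e])
    lengths = torch.tensor([n for _, _, n in index])
    padded = pad_sequence(segs)  # (max_len, num_segments, F), zero-padded at the end
    packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)
    out, (h_n, c_n) = lstm(packed, (torch.stack(h0, 1), torch.stack(c0, 1)))
    out, _ = pad_packed_sequence(out)  # (max_len, num_segments, H)
    # Scatter the segment outputs back to the original (T, B, H) layout.
    result = out.new_zeros(T, B, lstm.hidden_size)
    last = {}
    for i, (e, s, n) in enumerate(index):
        result[s:s + n, e] = out[:n, i]
        last[e] = i  # segments are appended in time order per env
    order = [last[e] for e in range(B)]
    return result, (h_n[:, order], c_n[:, order])


torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=3, hidden_size=4)
hidden = torch.randn(6, 2, 3)
done = torch.zeros(6, 2)
done[4, 0] = 1.0               # env 0: [a1, a2, a3, a4 | a5, a6]
done[2, 1] = done[5, 1] = 1.0  # env 1: [b1, b2 | b3, b4, b5 | b6]
state = (torch.randn(1, 2, 4), torch.randn(1, 2, 4))
with torch.no_grad():
    ref_out, ref_state = lstm_sequential(lstm, hidden, done, state)
    out, new_state = lstm_split(lstm, hidden, done, state)
print(torch.allclose(ref_out, out, atol=1e-5))  # True
```

The Python loops that build the segments and scatter the outputs are kept for readability; a faster version would compute these indices with tensor operations.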

Proposal

I implemented a version of the script that uses this trick to process the sequences with one call. In my setup, it made training about 4x faster, at the cost of roughly 2x higher memory usage. The final performance of the policy is similar to that of the original script.

Would you be interested in adding this script to the repo? Should I make a PR to create a new file using this trick?

@vwxyzjn
Owner

vwxyzjn commented Sep 19, 2022

Thanks @thomasbbrunner. This looks like a really cool trick. CC @araffin. This idea seems to be related to Stable-Baselines-Team/stable-baselines3-contrib#53 (comment).

Would you be interested in comparing this technique in JAX? We are gradually adopting JAX, which is much faster (see #227 (comment)). The current PPO + LSTM implementation is slow because of the Python for loop, but we might be able to speed it up considerably with JAX and JIT compilation.

In that sense, I would be interested to see the performance difference between split sequences and original sequences in JAX. If there is no significant difference when using JIT, this technique might not be worth adopting.

@thomasbbrunner
Author

Thanks for the interest!

I am not very familiar with JAX. Currently, I don't have the time to look into this, but I am interested in it and I will make time for it in the near future.

@vwxyzjn
Owner

vwxyzjn commented Oct 24, 2022

Sounds good, thank you!
