Problem Description
The current PPO with LSTM script ppo_atari_lstm.py uses sequential stepping through the LSTM, i.e. each step of the sequence is processed individually.
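The loop looks roughly like this (a paraphrased sketch of the pattern, not a verbatim excerpt from the script):

```python
import torch

def lstm_per_step(lstm, hidden, done, lstm_state):
    # hidden: (T, N, input_size), done: (T, N),
    # lstm_state: two tensors of shape (1, N, hidden_size)
    new_hidden = []
    for h, d in zip(hidden, done):
        # Zero the recurrent state of any environment whose episode just ended,
        # then advance the LSTM by a single timestep.
        h, lstm_state = lstm(
            h.unsqueeze(0),
            (
                (1.0 - d).view(1, -1, 1) * lstm_state[0],
                (1.0 - d).view(1, -1, 1) * lstm_state[1],
            ),
        )
        new_hidden.append(h)
    return torch.flatten(torch.cat(new_hidden), 0, 1), lstm_state
```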
This method is very slow compared to sending the entire sequence of observations to the LSTM:
```python
h, lstm_state = self.lstm(hidden, lstm_state)
```
This usually cannot be done in RL, as we have to reset the hidden states when an episode ends.
Other implementations of PPO use a trick: split a sequence that contains several trajectories into several sequences that each contain only one trajectory. This is done by cutting the input sequence wherever there is a done and padding the resulting sub-sequences to a common length. For example, a 5-step rollout whose episode ends after step 2 becomes two sequences: steps 1-2 followed by padding, and steps 3-5.
With this trick, a single LSTM call can process all the sequences in the batch at once.
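A minimal sketch of the idea, assuming done[t] == 1 means the state must be reset before step t; the helper name split_and_pad is mine, and a real implementation would also carry the stored initial LSTM states for segments that continue an episode, and map the un-padded outputs back to their (step, env) positions for the PPO update:

```python
import torch
import torch.nn as nn

def split_and_pad(hidden, done):
    """Hypothetical helper: cut each env's sequence at episode boundaries and pad.

    hidden: (T, N, features), done: (T, N) with done[t] == 1 meaning the LSTM
    state must be reset before step t. Returns (T_max, num_segments, features)
    plus a boolean mask marking the non-padded steps.
    """
    T, N, _ = hidden.shape
    segments = []
    for env in range(N):
        start = 0
        for t in range(1, T):
            if done[t, env]:
                segments.append(hidden[start:t, env])
                start = t
        segments.append(hidden[start:, env])
    lengths = torch.tensor([len(s) for s in segments])
    padded = nn.utils.rnn.pad_sequence(segments)                 # (T_max, S, F)
    mask = torch.arange(padded.shape[0]).unsqueeze(1) < lengths  # (T_max, S)
    return padded, mask

# Single LSTM call over all segments instead of a python loop over timesteps:
# padded, mask = split_and_pad(hidden, done)
# out, _ = lstm(padded)        # (T_max, num_segments, hidden_size)
# out = out[mask]              # drop the padded steps before the PPO update
```

Padding is where the extra memory goes; torch.nn.utils.rnn.pack_padded_sequence could avoid computing on the padded steps, at the cost of a bit more bookkeeping.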
Proposal
I implemented a version of the script that uses this trick to process sequences with one call. In my setup, it led to a 4x improvement in training speed. However, it comes with a higher memory usage (about 2x in my setup). The final performance of the policy is similar to the original script.
Would you be interested in adding this script to the repo? Should I make a PR to create a new file using this trick?
Would you be interested in comparing this technique in JAX? We are gradually adopting JAX, which is much faster (see #227 (comment)). The current PPO + LSTM implementation is slow due to the python for loop, but we might be able to speed it up considerably using JIT and JAX.
In that sense, I would be interested to see the performance difference between doing split sequences and original sequences using JAX — if there is no significant difference when using JIT, then it might not be worth doing this technique.
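Roughly, the python loop would become a jax.lax.scan over timesteps (a sketch with a hand-rolled LSTM cell and made-up sizes, just to show the shape of the solution); under jit the whole scan compiles into a single XLA computation:

```python
import jax
import jax.numpy as jnp

def lstm_cell(params, carry, x):
    """Minimal LSTM cell (hand-rolled for the sketch): carry = (h, c)."""
    h, c = carry
    z = x @ params["wx"] + h @ params["wh"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (h, c), h

@jax.jit
def unroll(params, carry, xs, dones):
    # xs: (T, N, in_dim), dones: (T, N). Reset the carry wherever an episode ended.
    def step(carry, inp):
        x, d = inp
        h, c = carry
        mask = (1.0 - d)[:, None]
        return lstm_cell(params, (mask * h, mask * c), x)
    return jax.lax.scan(step, carry, (xs, dones))

# Toy usage with hypothetical sizes:
T, N, in_dim, hid = 128, 8, 512, 128
key = jax.random.PRNGKey(0)
params = {
    "wx": jax.random.normal(key, (in_dim, 4 * hid)) * 0.01,
    "wh": jax.random.normal(key, (hid, 4 * hid)) * 0.01,
    "b": jnp.zeros((4 * hid,)),
}
carry = (jnp.zeros((N, hid)), jnp.zeros((N, hid)))
carry, outputs = unroll(params, carry, jnp.zeros((T, N, in_dim)), jnp.zeros((T, N)))
# outputs: (T, N, hid)
```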
I am not very familiar with JAX. Currently, I don't have the time to look into this, but I am interested in it and I will make time for it in the near future.