Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recurrent PPO #53

Merged
merged 60 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
85c9a50
Running (not working yet) version of recurrent PPO
araffin Nov 22, 2021
b92da74
Fixes for multi envs
araffin Nov 23, 2021
d9f9c4e
Save WIP, rework the sampling
araffin Nov 23, 2021
97ec8ec
Add Box support
araffin Nov 23, 2021
a890976
Fix sample order
araffin Nov 23, 2021
7fecd9f
Being cleanup, code is broken (again)
araffin Nov 25, 2021
0ddc3f6
First working version (no shared lstm)
araffin Nov 25, 2021
5ef313b
Start cleanup
araffin Nov 25, 2021
c803ac9
Try rnn with value function
araffin Nov 25, 2021
0c8ab15
Re-enable batch size
araffin Nov 25, 2021
eb1e6c1
Deactivate vf rnn
araffin Nov 25, 2021
f013346
Allow any batch size
araffin Nov 25, 2021
a14f2ce
Add support for evaluation
araffin Nov 26, 2021
362dec4
Add CNN support
araffin Nov 26, 2021
5b162db
Fix start of sequence
araffin Nov 26, 2021
954e6dd
Allow shared LSTM
araffin Nov 26, 2021
832093d
Rename mask to episode_start
araffin Nov 28, 2021
2a9c956
Fix type hint
araffin Nov 28, 2021
15c080a
Enable LSTM for critic
araffin Nov 28, 2021
0d304aa
Clean code
araffin Nov 28, 2021
1dc78b4
Fix for CNN LSTM
araffin Nov 28, 2021
deaa7b4
Fix sampling with n_layers > 1
araffin Nov 28, 2021
ced6aee
Add std logger
araffin Nov 29, 2021
b81fdff
Update wording
araffin Nov 30, 2021
a2a201f
Merge branch 'master' into feat/ppo-lstm
araffin Dec 1, 2021
754e0a3
Merge branch 'master' into feat/ppo-lstm
araffin Dec 10, 2021
c9c0b4e
Rename and add dict obs support
araffin Dec 27, 2021
a4b769f
Fixes for dict obs support
araffin Dec 27, 2021
5cadc14
Do not run slow tests
araffin Dec 27, 2021
617d76f
Merge branch 'master' into feat/ppo-lstm
araffin Dec 29, 2021
c1f8812
Fix doc
araffin Dec 29, 2021
579e7d0
Update recurrent PPO example
araffin Dec 29, 2021
bd2d5e2
Update README
araffin Dec 29, 2021
072622b
Merge branch 'master' into feat/ppo-lstm
araffin Jan 3, 2022
c113324
Merge branch 'master' into feat/ppo-lstm
araffin Jan 19, 2022
4adb3ea
Merge branch 'master' into feat/ppo-lstm
araffin Feb 22, 2022
32ec1b2
Merge branch 'master' into feat/ppo-lstm
araffin Feb 23, 2022
0f0ce0b
Use Pendulum-v1 for tests
araffin Feb 23, 2022
116d0a6
Fix image env
araffin Feb 23, 2022
c32bb74
Speedup LSTM forward pass (#63)
Walon1998 Mar 8, 2022
638dfb2
Merge branch 'master' into feat/ppo-lstm
araffin Apr 12, 2022
86e0f6f
Fixes
araffin Apr 12, 2022
3fc6e51
Remove OpenAI sampling and improve coverage
araffin Apr 12, 2022
88f9504
Sync with SB3 PPO
araffin Apr 12, 2022
662f218
Pass state shape and allow lstm kwargs
araffin Apr 12, 2022
fd06850
Update tests
araffin Apr 12, 2022
f5e9b34
Add masking for padded sequences
araffin Apr 12, 2022
1cd27da
Update default in perf test
araffin Apr 12, 2022
c52959b
Remove TODO, mask is now working
araffin Apr 15, 2022
18e6230
Merge branch 'master' into feat/ppo-lstm
araffin Apr 25, 2022
673d23a
Add helper to remove duplicated code, remove hack for padding
araffin May 1, 2022
e271d03
Enable LSTM critic and raise threshold for cartpole with no vel
araffin May 8, 2022
73bb89c
Fix tests
araffin May 8, 2022
efa6181
Update doc and tests
araffin May 18, 2022
564d428
Doc fix
araffin May 18, 2022
408ed24
Fix for new Sphinx version
araffin May 29, 2022
d917487
Merge branch 'master' into feat/ppo-lstm
araffin May 29, 2022
6acb64a
Fix doc note
araffin May 29, 2022
5fd8be7
Switch to batch first, no more additional swap
araffin May 30, 2022
7a1d3e8
Add comments and mask entropy loss
araffin May 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ We hope this allows us to provide reliable implementations following stable-base
See documentation for the full list of included features.

**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [PPO with recurrent policy (RecurrentPPO)](https://arxiv.org/abs/1707.06347)
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
Expand Down
2 changes: 2 additions & 0 deletions docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ along with some useful characteristics: support for discrete/continuous actions,
Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
ARS ✔️ ❌️ ❌ ❌ ✔️
MaskablePPO ❌ ✔️ ✔️ ✔️ ✔️
QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
RecurrentPPO ✔️ ✔️ ✔️ ✔️ ✔️
TQC ✔️ ❌ ❌ ❌ ✔️
TRPO ✔️ ✔️ ✔️ ✔️ ✔️
============ =========== ============ ================= =============== ================
Expand Down
25 changes: 25 additions & 0 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,28 @@ Train an agent using Augmented Random Search (ARS) agent on the Pendulum environ
model = ARS("LinearPolicy", "Pendulum-v0", verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("ars_pendulum")

RecurrentPPO
------------

Train a PPO agent with a recurrent policy on the CartPole environment.

.. code-block:: python

import numpy as np

from sb3_contrib import RecurrentPPO

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)

env = model.get_env()
obs = env.reset()
lstm_states = None
num_envs = 1
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
episode_starts = dones
env.render()
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

modules/ars
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
modules/tqc
modules/trpo
Expand Down
23 changes: 23 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,29 @@
Changelog
==========

Release 1.4.1a0 (WIP)
-------------------------------
**Add Recurrent PPO**

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added ``RecurrentPPO``

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 1.4.0 (2022-01-19)
-------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/modules/ppo_mask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Maskable PPO
============

Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization(PPO)
Implementation of `invalid action masking <https://arxiv.org/abs/2006.14171>`_ for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm.


Expand Down
127 changes: 127 additions & 0 deletions docs/modules/ppo_recurrent.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
.. _ppo_mask:

.. automodule:: sb3_contrib.ppo_recurrent

Recurrent PPO
=============

Implementation of recurrent policies for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm.


.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpLstmPolicy
CnnLstmPolicy
MultiInputLstmPolicy


Notes
-----

.. - Paper: https://arxiv.org/abs/2006.14171
.. - Blog post: https://costa.sh/blog-a-closer-look-at-invalid-action-masking-in-policy-gradient-algorithms.html


Can I use?
----------

- Recurrent policies: ✔️
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ✔️ ✔️
MultiDiscrete ✔️ ✔️
MultiBinary ✔️ ✔️
Dict ❌ ✔️
============= ====== ===========


Example
-------


.. code-block:: python

import numpy as np

from sb3_contrib import RecurrentPPO
from stable_baselines3.common.evaluation import evaluate_policy

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5000)

env = model.get_env()
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
print(mean_reward)

model.save("ppo_recurrent")
del model # remove to demonstrate saving and loading

model = RecurrentPPO.load("ppo_recurrent")

obs = env.reset()
lstm_states = None
num_envs = 1
episode_starts = np.ones((num_envs,), dtype=bool)
while True:
action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
obs, rewards, dones, info = env.step(action)
episode_starts = dones
env.render()



Results
-------

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the repo for the experiment:

.. code-block:: bash

git clone https://github.com/DLR-RM/rl-baselines3-zoo
git checkout feat/recurrent-ppo

Parameters
----------

.. autoclass:: RecurrentPPO
:members:
:inherited-members:


RecurrentPPO Policies
---------------------

.. autoclass:: MlpLstmPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy
:members:
:noindex:

.. autoclass:: CnnLstmPolicy
:members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy
:members:
:noindex:

.. autoclass:: MultiInputLstmPolicy
:members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy
:members:
:noindex:
1 change: 1 addition & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sb3_contrib.ars import ARS
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO
Expand Down
10 changes: 5 additions & 5 deletions sb3_contrib/common/maskable/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,12 @@ def predict(
action_masks: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action and state from an observation (and optional state).
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).

:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:param action_masks: Action masks to apply to the action distribution
:return: the model's action and the next state
Expand All @@ -229,8 +229,8 @@ def predict(
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
# if episode_start is None:
# episode_start = [False for _ in range(self.n_envs)]

# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
Expand All @@ -256,7 +256,7 @@ def predict(
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
actions = actions[0]

return actions, state
return actions, None

def evaluate_actions(
self,
Expand Down
Empty file.
Loading