Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SACD Discrete Soft Actor Critic #203

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
modules/sacd
modules/tqc
modules/trpo

Expand Down
99 changes: 99 additions & 0 deletions docs/modules/sacd.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
.. _sacd:

.. automodule:: sb3_contrib.sacd


SACD
====


`Soft Actor Critic Discrete (SACD) <https://arxiv.org/abs/1910.07207>`_ is a modification of the original Soft Actor Critic Algorithm for discrete action spaces.

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
CnnPolicy
MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1910.07207
- Original Implementation: https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ❌ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
Dict ❌ ✔️
============= ====== ===========


Example
-------
.. code-block:: python

import gymnasium as gym

from sb3_contrib import SACD

env = gym.make("CartPole-v1", render_mode="rgb_array")

model = SACD("MlpPolicy", env, verbose=1, policy_kwargs=dict(net_arch=[64,64]))
model.learn(total_timesteps=20_000)
model.save("sacd_cartpole")

del model # remove to demonstrate saving and loading

model = SACD.load("sac_cartpole")

obs, info = env.reset()
while True:
action, _states = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, info = env.reset()



Parameters
----------

.. autoclass:: SACD
:members:
:inherited-members:

.. _sac_policies:

SACD Policies
-------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: stable_baselines3.sac.policies.SACPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:

.. autoclass:: MultiInputPolicy
:members:
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.sacd import SACD
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO

Expand All @@ -19,4 +20,5 @@
"QRDQN",
"TQC",
"TRPO",
"SACD",
]
4 changes: 4 additions & 0 deletions sb3_contrib/sacd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sb3_contrib.sacd.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.sacd.sacd import SACD

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "SACD"]
Loading