Skip to content

Commit

Permalink
Add policy documentation links to policy_kwargs parameter (#266)
Browse files Browse the repository at this point in the history
* Add policy documentation links to policy_kwargs parameter

* Sort `__all__`

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
kplers and araffin authored Dec 2, 2024
1 parent 36c21ac commit e1ca24a
Show file tree
Hide file tree
Showing 15 changed files with 17 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/modules/ppo_mask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ Parameters
:members:
:inherited-members:

.. _ppo_mask_policies:

MaskablePPO Policies
--------------------
Expand Down
1 change: 1 addition & 0 deletions docs/modules/ppo_recurrent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Parameters
:members:
:inherited-members:

.. _ppo_recurrent_policies:

RecurrentPPO Policies
---------------------
Expand Down
6 changes: 3 additions & 3 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

__all__ = [
"ARS",
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
"TQC",
"TRPO",
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
]
2 changes: 1 addition & 1 deletion sb3_contrib/ars/ars.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ARS(BaseAlgorithm):
:param zero_policy: Boolean determining if the passed policy should have its weights zeroed before training.
:param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
:param n_eval_episodes: Number of episodes to evaluate each candidate.
:param policy_kwargs: Keyword arguments to pass to the policy on creation
:param policy_kwargs: Keyword arguments to pass to the policy on creation. See :ref:`ars_policies`
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: String with the directory to put tensorboard logs:
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/common/torch_layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

__all__ = ["BatchRenorm1d", "BatchRenorm"]
__all__ = ["BatchRenorm", "BatchRenorm1d"]


class BatchRenorm(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CrossQ(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`crossq_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/ppo_mask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "MaskablePPO"]
__all__ = ["CnnPolicy", "MaskablePPO", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/ppo_mask/ppo_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_mask_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_recurrent_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/qrdqn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "QRDQN"]
__all__ = ["QRDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class QRDQN(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`qrdqn_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/tqc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.tqc.tqc import TQC

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TQC"]
__all__ = ["TQC", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TQC(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`tqc_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/trpo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.trpo.trpo import TRPO

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TRPO"]
__all__ = ["TRPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/trpo/trpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TRPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`trpo_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Expand Down

0 comments on commit e1ca24a

Please sign in to comment.