DuelingDQN #126

Closed · wants to merge 1 commit
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.dueling_dqn import DuelingDQN
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@

__all__ = [
    "ARS",
    "DuelingDQN",
    "MaskablePPO",
    "RecurrentPPO",
    "QRDQN",
4 changes: 4 additions & 0 deletions sb3_contrib/dueling_dqn/__init__.py
@@ -0,0 +1,4 @@
from sb3_contrib.dueling_dqn.dueling_dqn import DuelingDQN
from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

__all__ = ["DuelingDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
125 changes: 125 additions & 0 deletions sb3_contrib/dueling_dqn/dueling_dqn.py
@@ -0,0 +1,125 @@
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.dqn.dqn import DQN

from sb3_contrib.dueling_dqn.policies import DuelingDQNPolicy

SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN")


class DuelingDQN(DQN):
"""
Dueling Deep Q-Network (Dueling DQN)

Paper: https://arxiv.org/abs/1511.06581

:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param target_update_interval: update the target network every ``target_update_interval``
environment steps.
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
:param exploration_initial_eps: initial value of random action probability
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""

    def __init__(
        self,
        policy: Union[str, Type[DuelingDQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 0.0001,
        buffer_size: int = 1000000,
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            replay_buffer_class,
            replay_buffer_kwargs,
            optimize_memory_usage,
            target_update_interval,
            exploration_fraction,
            exploration_initial_eps,
            exploration_final_eps,
            max_grad_norm,
            tensorboard_log,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
        )

    def learn(
        self: SelfDuelingDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DuelingDQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDuelingDQN:
        return super().learn(
            total_timesteps,
            callback,
            log_interval,
            tb_log_name,
            reset_num_timesteps,
            progress_bar,
        )
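
For reference, a minimal usage sketch of the new class. The environment (CartPole-v1) and hyperparameters below are illustrative choices, not values taken from this PR:

    from sb3_contrib import DuelingDQN

    # Train on a small discrete-action task; settings are illustrative.
    model = DuelingDQN(
        "MlpPolicy",
        "CartPole-v1",
        learning_rate=1e-3,
        learning_starts=1_000,
        verbose=1,
    )
    model.learn(total_timesteps=20_000)

    # Greedy evaluation step on the vectorized training env
    obs = model.env.reset()
    action, _ = model.predict(obs, deterministic=True)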
182 changes: 182 additions & 0 deletions sb3_contrib/dueling_dqn/policies.py
@@ -0,0 +1,182 @@
from typing import Any, Dict, List, Optional, Type

import gym
import torch as th
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, CombinedExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.dqn.policies import DQNPolicy, QNetwork
from torch import nn


class DuelingQNetwork(QNetwork):
"""
Dueling Q-Network.

:param observation_space: Observation space
:param action_space: Action space
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        features_extractor: nn.Module,
        features_dim: int,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ):
        super().__init__(
            observation_space,
            action_space,
            features_extractor,
            features_dim,
            net_arch,
            activation_fn,
            normalize_images,
        )

        # The parent QNetwork already resolves ``net_arch=None`` to the
        # default ``[64, 64]`` and stores it as ``self.net_arch``.
        action_dim = self.action_space.n  # number of actions
        # State-value stream: a single scalar V(s) per state
        value_stream = create_mlp(self.features_dim, 1, self.net_arch, self.activation_fn)
        self.value_stream = nn.Sequential(*value_stream)
        # Advantage stream: one value A(s, a) per action
        advantage_stream = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
        self.advantage_stream = nn.Sequential(*advantage_stream)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """
        Predict the q-values.

        :param obs: Observation
        :return: The estimated Q-Value for each action.
        """
        features = self.extract_features(obs)
        values = self.value_stream(features)
        advantages = self.advantage_stream(features)
        # Combine the streams: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
        # The mean must be taken over the action dimension per sample, not
        # over the whole batch, so that V and A remain identifiable.
        qvals = values + (advantages - advantages.mean(dim=1, keepdim=True))
        return qvals
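
To make the aggregation step concrete, here is a small illustrative check (shapes and numbers are made up for the example, not taken from the PR):

    import torch as th

    # Batch of 2 observations, 3 discrete actions
    values = th.tensor([[1.0], [2.0]])            # V(s), shape (2, 1)
    advantages = th.tensor([[0.5, -0.5, 0.0],
                            [1.0, 0.0, -1.0]])    # A(s, a), shape (2, 3)
    q = values + (advantages - advantages.mean(dim=1, keepdim=True))
    # Each row of q averages back to V(s): subtracting the per-state mean
    # advantage is the identifiability trick from the dueling DQN paper.
    assert th.allclose(q.mean(dim=1, keepdim=True), values)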


class DuelingDQNPolicy(DQNPolicy):
"""
Policy class for Dueling DQN.

:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

    def make_q_net(self) -> DuelingQNetwork:
        # Make sure we always have separate networks for features extractors etc.
        net_args = self._update_features_extractor(self.net_args, features_extractor=None)
        return DuelingQNetwork(**net_args).to(self.device)


MlpPolicy = DuelingDQNPolicy


class CnnPolicy(DuelingDQNPolicy):
"""
Policy class for Dueling DQN when using images as input.

:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )


class MultiInputPolicy(DuelingDQNPolicy):
"""
Policy class for Dueling DQN when using dict observations as input.

:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

    def __init__(
        self,
        observation_space: gym.spaces.Dict,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
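
Because make_q_net goes through the standard DQNPolicy plumbing, the dueling streams pick up net_arch from policy_kwargs like any other DQN head. A sketch of customizing the stream width (values are illustrative):

    from sb3_contrib import DuelingDQN

    # Both the value and the advantage stream get their own MLP built from
    # this net_arch specification by create_mlp.
    model = DuelingDQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(net_arch=[128, 128]),
    )
    model.learn(total_timesteps=10_000)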