diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py
index 3fbd28d8..c4e02f64 100644
--- a/sb3_contrib/__init__.py
+++ b/sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
 import os
 
 from sb3_contrib.ars import ARS
+from sb3_contrib.dueling_dqn import DuelingDQN
 from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@
 
 __all__ = [
     "ARS",
+    "DuelingDQN",
     "MaskablePPO",
     "RecurrentPPO",
     "QRDQN",
diff --git a/sb3_contrib/dueling_dqn/__init__.py b/sb3_contrib/dueling_dqn/__init__.py
new file mode 100644
index 00000000..4243fae0
--- /dev/null
+++ b/sb3_contrib/dueling_dqn/__init__.py
@@ -0,0 +1,4 @@
+from sb3_contrib.dueling_dqn.dueling_dqn import DuelingDQN
+from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+
+__all__ = ["DuelingDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
diff --git a/sb3_contrib/dueling_dqn/dueling_dqn.py b/sb3_contrib/dueling_dqn/dueling_dqn.py
new file mode 100644
index 00000000..1e4c9f8b
--- /dev/null
+++ b/sb3_contrib/dueling_dqn/dueling_dqn.py
@@ -0,0 +1,125 @@
+from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
+
+import torch as th
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.dqn.dqn import DQN
+
+from sb3_contrib.dueling_dqn.policies import DuelingDQNPolicy
+
+SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN")
+
+
+class DuelingDQN(DQN):
+    """
+    Dueling Deep Q-Network (Dueling DQN)
+
+    Paper: https://arxiv.org/abs/1511.06581
+
+    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+    :param env: The environment to learn from (if registered in Gym, can be str)
+    :param learning_rate: The learning rate, it can be a function
+        of the current progress remaining (from 1 to 0)
+    :param buffer_size: size of the replay buffer
+    :param learning_starts: how many steps of the model to collect transitions for before learning starts
+    :param batch_size: Minibatch size for each gradient update
+    :param tau: the soft update coefficient ("Polyak update", between 0 and 1), default 1 for hard update
+    :param gamma: the discount factor
+    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
+        like ``(5, "step")`` or ``(2, "episode")``.
+    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
+        Set to ``-1`` means to do as many gradient steps as steps done in the environment
+        during the rollout.
+    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
+        If ``None``, it will be automatically selected.
+    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
+    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
+        at a cost of more complexity.
+        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param target_update_interval: update the target network every ``target_update_interval``
+        environment steps.
+    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
+    :param exploration_initial_eps: initial value of random action probability
+    :param exploration_final_eps: final value of random action probability
+    :param max_grad_norm: The maximum value for the gradient clipping
+    :param tensorboard_log: the log location for tensorboard (if None, no logging)
+    :param policy_kwargs: additional arguments to be passed to the policy on creation
+    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
+        debug messages
+    :param seed: Seed for the pseudo random generators
+    :param device: Device (cpu, cuda, ...) on which the code should be run.
+        Setting it to auto, the code will be run on the GPU if possible.
+    :param _init_setup_model: Whether or not to build the network at the creation of the instance
+    """
+
+    def __init__(
+        self,
+        policy: Union[str, Type[DuelingDQNPolicy]],
+        env: Union[GymEnv, str],
+        learning_rate: Union[float, Schedule] = 0.0001,
+        buffer_size: int = 1000000,
+        learning_starts: int = 50000,
+        batch_size: int = 32,
+        tau: float = 1.0,
+        gamma: float = 0.99,
+        train_freq: Union[int, Tuple[int, str]] = 4,
+        gradient_steps: int = 1,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        optimize_memory_usage: bool = False,
+        target_update_interval: int = 10000,
+        exploration_fraction: float = 0.1,
+        exploration_initial_eps: float = 1.0,
+        exploration_final_eps: float = 0.05,
+        max_grad_norm: float = 10,
+        tensorboard_log: Optional[str] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = None,
+        verbose: int = 0,
+        seed: Optional[int] = None,
+        device: Union[th.device, str] = "auto",
+        _init_setup_model: bool = True,
+    ):
+        super().__init__(
+            policy,
+            env,
+            learning_rate,
+            buffer_size,
+            learning_starts,
+            batch_size,
+            tau,
+            gamma,
+            train_freq,
+            gradient_steps,
+            replay_buffer_class,
+            replay_buffer_kwargs,
+            optimize_memory_usage,
+            target_update_interval,
+            exploration_fraction,
+            exploration_initial_eps,
+            exploration_final_eps,
+            max_grad_norm,
+            tensorboard_log,
+            policy_kwargs,
+            verbose,
+            seed,
+            device,
+            _init_setup_model,
+        )
+
+    def learn(
+        self: SelfDuelingDQN,
+        total_timesteps: int,
+        callback: MaybeCallback = None,
+        log_interval: int = 4,
+        tb_log_name: str = "DuelingDQN",
+        reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
+    ) -> SelfDuelingDQN:
+        return super().learn(
+            total_timesteps,
+            callback,
+            log_interval,
+            tb_log_name,
+            reset_num_timesteps,
+            progress_bar,
+        )
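Not part of the patch: a minimal usage sketch, assuming the import added to sb3_contrib/__init__.py above. Since DuelingDQN only overrides the policy, it is meant as a drop-in replacement for DQN; the environment id, timestep budget, and hyperparameters below are illustrative only.

import gym

from sb3_contrib import DuelingDQN

# Quick smoke test on a small discrete-action task (env id and settings are assumptions)
model = DuelingDQN("MlpPolicy", "CartPole-v1", learning_starts=1_000, verbose=1)
model.learn(total_timesteps=10_000)

env = gym.make("CartPole-v1")
obs = env.reset()
action, _states = model.predict(obs, deterministic=True)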
diff --git a/sb3_contrib/dueling_dqn/policies.py b/sb3_contrib/dueling_dqn/policies.py
new file mode 100644
index 00000000..6f87b56f
--- /dev/null
+++ b/sb3_contrib/dueling_dqn/policies.py
@@ -0,0 +1,182 @@
+from typing import Any, Dict, List, Optional, Type
+
+import gym
+import torch as th
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, CombinedExtractor, NatureCNN, create_mlp
+from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.dqn.policies import DQNPolicy, QNetwork
+from torch import nn
+
+
+class DuelingQNetwork(QNetwork):
+    """
+    Dueling Q-Network.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param net_arch: The specification of the value and advantage networks.
+    :param activation_fn: Activation function
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.spaces.Space,
+        action_space: gym.spaces.Space,
+        features_extractor: nn.Module,
+        features_dim: int,
+        net_arch: Optional[List[int]] = None,
+        activation_fn: Type[nn.Module] = nn.ReLU,
+        normalize_images: bool = True,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            features_extractor,
+            features_dim,
+            net_arch,
+            activation_fn,
+            normalize_images,
+        )
+
+        # ``self.net_arch`` already defaults to [64, 64] in the parent constructor.
+        # The parent also builds ``self.q_net``, which the dueling forward pass below
+        # does not use; Q-values come from the value and advantage streams instead.
+        action_dim = self.action_space.n  # number of actions
+        value_stream = create_mlp(self.features_dim, 1, self.net_arch, self.activation_fn)
+        self.value_stream = nn.Sequential(*value_stream)
+        advantage_stream = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
+        self.advantage_stream = nn.Sequential(*advantage_stream)
+
+    def forward(self, obs: th.Tensor) -> th.Tensor:
+        """
+        Predict the q-values.
+
+        :param obs: Observation
+        :return: The estimated Q-Value for each action.
+        """
+        features = self.extract_features(obs)
+        values = self.value_stream(features)
+        advantages = self.advantage_stream(features)
+        # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), with the mean taken over the
+        # action dimension of each sample rather than over the whole batch
+        qvals = values + (advantages - advantages.mean(dim=1, keepdim=True))
+        return qvals
+
+
+class DuelingDQNPolicy(DQNPolicy):
+    """
+    Policy class for Dueling DQN.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    """
+
+    def make_q_net(self) -> DuelingQNetwork:
+        # Make sure we always have separate networks for features extractors etc
+        net_args = self._update_features_extractor(self.net_args, features_extractor=None)
+        return DuelingQNetwork(**net_args).to(self.device)
+
+
+MlpPolicy = DuelingDQNPolicy
+
+
+class CnnPolicy(DuelingDQNPolicy):
+    """
+    Policy class for Dueling DQN when using images as input.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.spaces.Space,
+        action_space: gym.spaces.Space,
+        lr_schedule: Schedule,
+        net_arch: Optional[List[int]] = None,
+        activation_fn: Type[nn.Module] = nn.ReLU,
+        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        normalize_images: bool = True,
+        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            lr_schedule,
+            net_arch,
+            activation_fn,
+            features_extractor_class,
+            features_extractor_kwargs,
+            normalize_images,
+            optimizer_class,
+            optimizer_kwargs,
+        )
+
+
+class MultiInputPolicy(DuelingDQNPolicy):
+    """
+    Policy class for Dueling DQN when using dict observations as input.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.spaces.Dict,
+        action_space: gym.spaces.Space,
+        lr_schedule: Schedule,
+        net_arch: Optional[List[int]] = None,
+        activation_fn: Type[nn.Module] = nn.ReLU,
+        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        normalize_images: bool = True,
+        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            lr_schedule,
+            net_arch,
+            activation_fn,
+            features_extractor_class,
+            features_extractor_kwargs,
+            normalize_images,
+            optimizer_class,
+            optimizer_kwargs,
+        )
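For reference, a short self-contained sketch (not part of the patch) of the dueling aggregation implemented in DuelingQNetwork.forward, Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); the tensor values are toy numbers chosen for illustration.

import torch as th

batch_size, n_actions = 2, 3
values = th.tensor([[1.0], [2.0]])                            # V(s): shape (batch_size, 1)
advantages = th.tensor([[0.5, 1.5, 1.0], [3.0, 0.0, 0.0]])    # A(s, a): shape (batch_size, n_actions)
# Subtract the per-sample mean over the action dimension so that V and A stay identifiable:
# shifting all advantages of a state by a constant leaves the resulting Q-values unchanged.
q_values = values + (advantages - advantages.mean(dim=1, keepdim=True))
print(q_values.shape)  # torch.Size([2, 3])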