From 85c9a507558b556ee8b36843634ac00c8afaecb1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 22 Nov 2021 18:04:26 +0100 Subject: [PATCH 01/50] Running (not working yet) version of recurrent PPO --- sb3_contrib/__init__.py | 1 + sb3_contrib/common/recurrent/__init__.py | 0 sb3_contrib/common/recurrent/buffers.py | 233 +++++++++ sb3_contrib/common/recurrent/policies.py | 171 +++++++ sb3_contrib/common/recurrent/torch_layers.py | 79 +++ sb3_contrib/common/recurrent/utils.py | 0 sb3_contrib/ppo_lstm/__init__.py | 3 + sb3_contrib/ppo_lstm/policies.py | 13 + sb3_contrib/ppo_lstm/ppo_lstm.py | 488 +++++++++++++++++++ setup.cfg | 1 + tests/test_lstm.py | 53 ++ 11 files changed, 1042 insertions(+) create mode 100644 sb3_contrib/common/recurrent/__init__.py create mode 100644 sb3_contrib/common/recurrent/buffers.py create mode 100644 sb3_contrib/common/recurrent/policies.py create mode 100644 sb3_contrib/common/recurrent/torch_layers.py create mode 100644 sb3_contrib/common/recurrent/utils.py create mode 100644 sb3_contrib/ppo_lstm/__init__.py create mode 100644 sb3_contrib/ppo_lstm/policies.py create mode 100644 sb3_contrib/ppo_lstm/ppo_lstm.py create mode 100644 tests/test_lstm.py diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index c90336af..4420815d 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,5 +1,6 @@ import os +from sb3_contrib.ppo_lstm import RecurrentPPO from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC diff --git a/sb3_contrib/common/recurrent/__init__.py b/sb3_contrib/common/recurrent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py new file mode 100644 index 00000000..b5f8679a --- /dev/null +++ b/sb3_contrib/common/recurrent/buffers.py @@ -0,0 +1,233 @@ +from typing import Generator, NamedTuple, Optional, Tuple, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.type_aliases import TensorDict +from stable_baselines3.common.vec_env import VecNormalize + + +class RecurrentRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: Tuple[th.Tensor, th.Tensor] + dones: th.Tensor + + +class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: Tuple[th.Tensor, th.Tensor] + dones: th.Tensor + + +class RecurrentRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the invalid action masks associated with each observation. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
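# Illustrative sketch (not part of the patch): this buffer inherits SB3's
# ``compute_returns_and_advantage`` unchanged, so ``gae_lambda`` enters through the usual GAE
# recursion. A simplified standalone version (function name and signature are illustrative only):
import numpy as np

def gae(rewards, values, episode_starts, last_values, last_dones, gamma=0.99, gae_lambda=0.95):
    # rewards, values, episode_starts: (n_steps, n_envs); last_values, last_dones: (n_envs,)
    advantages = np.zeros_like(rewards)
    last_gae = 0.0
    for step in reversed(range(len(rewards))):
        if step == len(rewards) - 1:
            next_non_terminal, next_values = 1.0 - last_dones, last_values
        else:
            next_non_terminal, next_values = 1.0 - episode_starts[step + 1], values[step + 1]
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[step] = last_gae
    # gae_lambda = 1 recovers the Monte-Carlo ("classic") advantage mentioned above
    return advantages, advantages + values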
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + lstm_states: Tuple[np.ndarray, np.ndarray], + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.lstm_states = lstm_states + self.dones = None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + + def reset(self): + self.hidden_states = np.zeros_like(self.lstm_states[0]) + self.cell_states = np.zeros_like(self.lstm_states[1]) + self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + super().reset() + + def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], dones: np.ndarray, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states[self.pos] = np.array(lstm_states[0]) + self.cell_states[self.pos] = np.array(lstm_states[1]) + self.dones[self.pos] = np.array(dones) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + assert self.full, "" + # Do not shuffle + indices = np.arange(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + self.hidden_states = self.hidden_states.swapaxes(1, 2) + self.cell_states = self.cell_states.swapaxes(1, 2) + + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states", + "cell_states", + "dones", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RecurrentRolloutBufferSamples: + return RecurrentRolloutBufferSamples( + observations=self.to_torch(self.observations[batch_inds]), + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + lstm_states=(self.to_torch(self.hidden_states[batch_inds][0]), self.to_torch(self.cell_states[batch_inds][0])), + dones=self.to_torch(self.dones[batch_inds]), + ) + + +class RecurrentDictRolloutBuffer(DictRolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. 
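# Reference sketch (not part of the patch): ``get()`` above leans on SB3's
# ``BaseBuffer.swap_and_flatten`` helper, which is why the stored LSTM states are first swapped to
# (n_steps, n_envs, num_layers, hidden_size). The helper merges the step and env axes in an
# env-major order, roughly:
import numpy as np

def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
    # (n_steps, n_envs, *feature_shape) -> (n_steps * n_envs, *feature_shape),
    # keeping all steps of env 0 together, then env 1, etc.
    shape = arr.shape
    if len(shape) < 3:
        shape = (*shape, 1)
    return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])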
+ Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + lstm_states: Tuple[np.ndarray, np.ndarray], + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + super(RecurrentDictRolloutBuffer, self).__init__( + buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs + ) + self.lstm_states = lstm_states + self.dones = None + + def reset(self): + self.hidden_states = np.zeros_like(self.lstm_states[0]) + self.cell_states = np.zeros_like(self.lstm_states[1]) + self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + super().reset() + + def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], dones: np.ndarray, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states[self.pos] = np.array(lstm_states[0]) + self.cell_states[self.pos] = np.array(lstm_states[1]) + self.dones[self.pos] = np.array(dones) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: + assert self.full, "" + # indices = np.random.permutation(self.buffer_size * self.n_envs) + # Do not shuffle the data + indices = np.arange(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + self.hidden_states = self.hidden_states.swapaxes(1, 2) + self.cell_states = self.cell_states.swapaxes(1, 2) + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states", + "cell_states", + "dones", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RecurrentDictRolloutBufferSamples: + + return RecurrentDictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + lstm_states=(self.to_torch(self.hidden_states[batch_inds]), self.to_torch(self.cell_states[batch_inds])), + 
dones=self.to_torch(self.dones[batch_inds]), + ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py new file mode 100644 index 00000000..fd57ee54 --- /dev/null +++ b/sb3_contrib/common/recurrent/policies.py @@ -0,0 +1,171 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from stable_baselines3.common.type_aliases import Schedule +from torch import nn + +from sb3_contrib.common.recurrent.torch_layers import LSTMExtractor + +# CombinedExtractor,; FlattenExtractor,; MlpExtractor,; NatureCNN,; create_mlp, + + +class RecurrentActorCriticPolicy(ActorCriticPolicy): + """ + CNN policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
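# Illustrative sketch (not part of the patch): since ``LSTMExtractor`` (defined in
# torch_layers.py below) is the default features extractor, its ``hidden_size`` and
# ``num_layers`` arguments can be tuned through ``features_extractor_kwargs``; ``env`` is a
# placeholder for any vectorized environment:
from sb3_contrib import RecurrentPPO

model = RecurrentPPO(
    "MlpLstmPolicy",
    env,
    policy_kwargs=dict(features_extractor_kwargs=dict(hidden_size=128, num_layers=2)),
)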
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = LSTMExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + def _predict( + self, + observation: th.Tensor, + deterministic: bool = False, + lstm_states: Optional[Tuple[th.Tensor, th.Tensor]] = None, + ) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param lstm_states: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + self.features_extractor.set_lstm_states(lstm_states) + return self.get_distribution(observation).get_actions(deterministic=deterministic) + + def get_lstm_states(self) -> Tuple[th.Tensor, th.Tensor]: + return self.features_extractor.lstm_states + + def set_lstm_states(self, lstm_states: Optional[Tuple[th.Tensor]] = None) -> None: + self.features_extractor.set_lstm_states(lstm_states) + + def set_dones(self, dones: th.Tensor) -> None: + self.features_extractor.set_dones(dones) + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Get the policy action and state from an observation (and optional state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. 
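# Illustrative usage sketch (not part of the patch, and only meaningful once the eval TODOs below
# are resolved): the state returned by ``predict`` is fed back on the next call so the LSTM memory
# persists across steps, and ``None`` restarts from the initial state. ``model`` and ``vec_env``
# are placeholders:
obs = vec_env.reset()
lstm_state = None
for _ in range(1000):
    action, lstm_state = model.predict(obs, state=lstm_state, deterministic=True)
    obs, reward, done, info = vec_env.step(action)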
+ :return: the model's action and the next state + (used in recurrent policies) + """ + # TODO (GH/1): add support for RNN policies + # if state is None: + # state = self.features_extractor.initial_lstm_states + # if mask is None: + # mask = [False for _ in range(self.n_envs)] + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + # TODO(antonin): preprocess state + # if state is not None: + # lstm_states = None + + with th.no_grad(): + actions = self._predict(observation, lstm_states=state, deterministic=deterministic) + states = self.get_lstm_states() + states = (states[0].cpu().numpy(), states[1].cpu().numpy()) + # TODO(antonin): fix eval script + states = None + + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions[0] + + return actions, states diff --git a/sb3_contrib/common/recurrent/torch_layers.py b/sb3_contrib/common/recurrent/torch_layers.py new file mode 100644 index 00000000..3a0f739f --- /dev/null +++ b/sb3_contrib/common/recurrent/torch_layers.py @@ -0,0 +1,79 @@ +from copy import deepcopy +from typing import Optional, Tuple + +import gym +import torch as th +from stable_baselines3.common.preprocessing import get_flattened_obs_dim +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from torch import nn + + +class LSTMExtractor(BaseFeaturesExtractor): + """ + Feature extract that pass the data through an LSTM after flattening it. + Used as a placeholder when feature extraction is not needed. 
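# Illustrative sketch (not part of the patch): the extractor follows ``torch.nn.LSTM``'s state
# convention -- the recurrent state is a ``(hidden, cell)`` pair, each of shape
# (num_layers, n_envs, hidden_size), and inputs are (seq_len, n_envs, features_dim):
import torch as th

n_envs, features_dim, hidden_size, num_layers = 4, 8, 64, 1
lstm = th.nn.LSTM(features_dim, hidden_size, num_layers=num_layers)
h0 = th.zeros(num_layers, n_envs, hidden_size)
c0 = th.zeros(num_layers, n_envs, hidden_size)
obs_sequence = th.randn(5, n_envs, features_dim)    # 5 timesteps
output, (h1, c1) = lstm(obs_sequence, (h0, c0))     # output: (5, n_envs, hidden_size)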
+ + :param observation_space: + """ + + def __init__(self, observation_space: gym.Space, hidden_size: int = 64, num_layers: int = 1): + super().__init__(observation_space, hidden_size) + self.flatten = nn.Flatten() + self.lstm = nn.LSTM(get_flattened_obs_dim(observation_space), hidden_size, num_layers=num_layers) + # One forward pass to initial hidden state + # dummy_cell_state, dummy_hidden = self.lstm() + # Cell and hidden state + n_envs = 1 + self.initial_hidden_state = (th.zeros(num_layers, n_envs, hidden_size), th.zeros(num_layers, n_envs, hidden_size)) + self._lstm_states = deepcopy(self.initial_hidden_state) + self.dones = None + + def reset_state(self) -> None: + self._lstm_states = deepcopy(self.initial_hidden_state) + self.dones = None + + def process_sequence( + self, + observations: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + dones: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor]: + features = self.flatten(observations) + + # LSTM logic + # (sequence length, batch size, features dim) (batch size = n envs) + batch_size = lstm_states[0].shape[1] + features_sequence = features.reshape((-1, batch_size, self.lstm.input_size)) + dones = dones.reshape((-1, batch_size)) + lstm_output = [] + # Iterate over the sequence + for features, done in zip(features_sequence, dones): + hidden, lstm_states = self.lstm( + features.unsqueeze(0), + ( + (1.0 - done).view(1, -1, 1) * lstm_states[0], + (1.0 - done).view(1, -1, 1) * lstm_states[1], + ), + ) + lstm_output += [hidden] + lstm_output = th.flatten(th.cat(lstm_output), start_dim=0, end_dim=1) + return lstm_output, lstm_states + + def set_lstm_states(self, lstm_states: Optional[Tuple[th.Tensor]] = None) -> None: + if lstm_states is None: + self.reset_state() + else: + self._lstm_states = deepcopy(lstm_states) + + def set_dones(self, dones: th.Tensor) -> None: + self.dones = dones + + @property + def lstm_states(self) -> Tuple[th.Tensor, th.Tensor]: + return self._lstm_states + + def forward(self, observations: th.Tensor) -> th.Tensor: + if self.dones is None: + self.dones = th.zeros(len(observations)).float().to(observations.device) + features, self._lstm_states = self.process_sequence(observations, self._lstm_states, self.dones) + return features diff --git a/sb3_contrib/common/recurrent/utils.py b/sb3_contrib/common/recurrent/utils.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/ppo_lstm/__init__.py b/sb3_contrib/ppo_lstm/__init__.py new file mode 100644 index 00000000..92333335 --- /dev/null +++ b/sb3_contrib/ppo_lstm/__init__.py @@ -0,0 +1,3 @@ +# from sb3_contrib.ppo_lstm.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_contrib.ppo_lstm.policies import MlpLstmPolicy +from sb3_contrib.ppo_lstm.ppo_lstm import RecurrentPPO diff --git a/sb3_contrib/ppo_lstm/policies.py b/sb3_contrib/ppo_lstm/policies.py new file mode 100644 index 00000000..d0d4d7f3 --- /dev/null +++ b/sb3_contrib/ppo_lstm/policies.py @@ -0,0 +1,13 @@ +from stable_baselines3.common.policies import register_policy + +from sb3_contrib.common.recurrent.policies import ( # RecurrentActorCriticCnnPolicy,; RecurrentMultiInputActorCriticPolicy, + RecurrentActorCriticPolicy, +) + +MlpLstmPolicy = RecurrentActorCriticPolicy +# CnnLstmPolicy = RecurrentActorCriticCnnPolicy +# MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy + +register_policy("MlpLstmPolicy", RecurrentActorCriticPolicy) +# register_policy("CnnLstmPolicy", RecurrentActorCriticCnnPolicy) +# register_policy("MultiInputLstmPolicy", 
RecurrentMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py new file mode 100644 index 00000000..700e1842 --- /dev/null +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -0,0 +1,488 @@ +import time +from typing import Any, Dict, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean +from stable_baselines3.common.vec_env import VecEnv +from torch.nn import functional as F + +from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer +from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy + + +class RecurrentPPO(OnPolicyAlgorithm): + """ + Proximal Policy Optimization algorithm (PPO) (clip version) + with support for recurrent policies (LSTM). + + Based on the original Stable Baselines 3 implementation. + + Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. 
+ Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[RecurrentActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 128, + batch_size: Optional[int] = 128, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + # TODO(antonin): add gSDE support + use_sde=False, + sde_sample_freq=-1, + tensorboard_log=tensorboard_log, + create_eval_env=create_eval_env, + policy_kwargs=policy_kwargs, + policy_base=ActorCriticPolicy, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=False, + supported_action_spaces=( + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.target_kl = target_kl + self.lstm_states = None + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + buffer_cls = ( + RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer + ) + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + if not isinstance(self.policy, RecurrentActorCriticPolicy): + raise ValueError("Policy must subclass RecurrentActorCriticPolicy") + + # TODO: handle multiple envs + # self.lstm_states = (th.tensor(lstm_states[0]).to(self.device), th.tensor(lstm_states[0]).to(self.device)) + lstm = self.policy.features_extractor.lstm + hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) + + self.rollout_buffer = buffer_cls( + self.n_steps, + self.observation_space, + self.action_space, + lstm_states, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def _setup_learn( + self, + total_timesteps: int, + eval_env: Optional[GymEnv], + callback: MaybeCallback = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + tb_log_name: str = "run", + ) -> 
Tuple[int, BaseCallback]: + """ + Initialize different variables needed for training. + + :param total_timesteps: The total number of samples (env steps) to train on + :param eval_env: Environment to use for evaluation. + :param callback: Callback(s) called at every step with state of the algorithm. + :param eval_freq: How many steps between evaluations + :param n_eval_episodes: How many episodes to play per evaluation + :param log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute + :param tb_log_name: the name of the run for tensorboard log + :return: + """ + + total_timesteps, callback = super()._setup_learn( + total_timesteps, + eval_env, + callback, + eval_freq, + n_eval_episodes, + log_path, + reset_num_timesteps, + tb_log_name, + ) + return total_timesteps, callback + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert isinstance( + rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) + ), "RolloutBuffer doesn't support recurrent policy" + + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + self.policy.set_lstm_states(self.lstm_states) + actions, values, log_probs = self.policy.forward(obs_tensor) + lstm_states = self.policy.get_lstm_states() + actions = actions.cpu().numpy() + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, gym.spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, gym.spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done_ in enumerate(dones): + if ( + done_ + and infos[idx].get("terminal_observation") is not None + and 
infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + self._last_obs, + actions, + rewards, + self._last_episode_starts, + values, + log_probs, + lstm_states=(lstm_states[0].cpu().numpy(), lstm_states[1].cpu().numpy()), + dones=dones, + ) + + self._last_obs = new_obs + self._last_episode_starts = dones + # Reset states if needed + for idx, done_ in enumerate(dones): + if done_: + lstm_states[0][:, idx, :] = 0.0 + lstm_states[1][:, idx, :] = 0.0 + self.lstm_states = lstm_states + + with th.no_grad(): + # Compute value for the last timestep + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.on_rollout_end() + + return True + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + self.policy.set_lstm_states(rollout_data.lstm_states) + self.policy.set_dones(rollout_data.dones) + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + ) + + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob) + else: + entropy_loss = -th.mean(entropy) + + 
entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "RecurrentPPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "RecurrentPPO": + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time)) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, 
exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + self.train() + + callback.on_training_end() + + return self diff --git a/setup.cfg b/setup.cfg index cf162f7b..2e8b5010 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators per-file-ignores = ./sb3_contrib/__init__.py:F401 ./sb3_contrib/ppo_mask/__init__.py:F401 + ./sb3_contrib/ppo_lstm/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401 diff --git a/tests/test_lstm.py b/tests/test_lstm.py new file mode 100644 index 00000000..eb90d19d --- /dev/null +++ b/tests/test_lstm.py @@ -0,0 +1,53 @@ +import numpy as np +from gym import spaces +from gym.envs.classic_control import CartPoleEnv +from gym.wrappers.time_limit import TimeLimit + +from sb3_contrib import RecurrentPPO + + +class CartPoleNoVelEnv(CartPoleEnv): + """Variant of CartPoleEnv with velocity information removed. This task requires memory to solve.""" + + def __init__(self): + super().__init__() + high = np.array( + [ + self.x_threshold * 2, + self.theta_threshold_radians * 2, + ] + ) + self.observation_space = spaces.Box(-high, high, dtype=np.float32) + + @staticmethod + def _pos_obs(full_obs): + xpos, _xvel, thetapos, _thetavel = full_obs + return xpos, thetapos + + def reset(self): + full_obs = super().reset() + return CartPoleNoVelEnv._pos_obs(full_obs) + + def step(self, action): + full_obs, rew, done, info = super().step(action) + return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info + + +def test_ppo_lstm(): + from stable_baselines3.common.env_util import make_vec_env + env = make_vec_env("CartPole-v1", n_envs=1) + # env = CartPoleNoVelEnv() + # env = TimeLimit(env, max_episode_steps=500) + + model = RecurrentPPO( + "MlpLstmPolicy", + env, + n_steps=2048, + learning_rate=3e-4, + verbose=1, + batch_size=64, + # create_eval_env=True, + ) + # model.learn(total_timesteps=500, eval_freq=250) + # model.learn(total_timesteps=100_000) + model.learn(total_timesteps=500) From b92da74b553343b6166655bee84811b71c02d4be Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 23 Nov 2021 16:02:46 +0100 Subject: [PATCH 02/50] Fixes for multi envs --- sb3_contrib/ppo_lstm/ppo_lstm.py | 9 +++++++-- tests/test_lstm.py | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 700e1842..42d6ebd9 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -143,11 +143,14 @@ def _setup_model(self) -> None: if not isinstance(self.policy, RecurrentActorCriticPolicy): raise ValueError("Policy must subclass RecurrentActorCriticPolicy") - # TODO: handle multiple envs - # self.lstm_states = (th.tensor(lstm_states[0]).to(self.device), th.tensor(lstm_states[0]).to(self.device)) lstm = self.policy.features_extractor.lstm hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) + single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) + self.lstm_states = ( + th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape).to(self.device), + ) self.rollout_buffer = buffer_cls( self.n_steps, @@ -286,6 +289,8 @@ def collect_rollouts( ): terminal_obs = 
self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): + terminal_lstm_state = lstm_states[0][:, idx : idx + 1, :], lstm_states[1][:, idx : idx + 1, :] + self.policy.set_lstm_states(terminal_lstm_state) terminal_value = self.policy.predict_values(terminal_obs)[0] rewards[idx] += self.gamma * terminal_value diff --git a/tests/test_lstm.py b/tests/test_lstm.py index eb90d19d..432ba759 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -35,8 +35,18 @@ def step(self, action): def test_ppo_lstm(): from stable_baselines3.common.env_util import make_vec_env - env = make_vec_env("CartPole-v1", n_envs=1) + + env = make_vec_env("CartPole-v1", n_envs=8) + + def make_env(): + env = CartPoleNoVelEnv() + env = TimeLimit(env, max_episode_steps=500) + return env + + env = make_vec_env(make_env, n_envs=2) # env = CartPoleNoVelEnv() + # import gym + # env = gym.make("CartPole-v1") # env = TimeLimit(env, max_episode_steps=500) model = RecurrentPPO( @@ -46,8 +56,9 @@ def test_ppo_lstm(): learning_rate=3e-4, verbose=1, batch_size=64, + seed=0, # create_eval_env=True, ) # model.learn(total_timesteps=500, eval_freq=250) # model.learn(total_timesteps=100_000) - model.learn(total_timesteps=500) + model.learn(total_timesteps=100) From d9f9c4ecef13da013f2c9d44596cc8c9cd021859 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 23 Nov 2021 18:09:27 +0100 Subject: [PATCH 03/50] Save WIP, rework the sampling --- sb3_contrib/common/recurrent/buffers.py | 54 ++++++++++++++++++++----- tests/test_lstm.py | 8 ++-- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index b5f8679a..3b47b020 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -77,14 +77,13 @@ def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], dones: np.ndarr def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "" - # Do not shuffle - indices = np.arange(self.buffer_size * self.n_envs) + # Prepare the data if not self.generator_ready: # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - self.hidden_states = self.hidden_states.swapaxes(1, 2) - self.cell_states = self.cell_states.swapaxes(1, 2) + # self.hidden_states = self.hidden_states.swapaxes(1, 2) + # self.cell_states = self.cell_states.swapaxes(1, 2) for tensor in [ "observations", @@ -93,8 +92,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf "log_probs", "advantages", "returns", - "hidden_states", - "cell_states", + # "hidden_states", + # "cell_states", "dones", ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) @@ -104,10 +103,43 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - yield self._get_samples(indices[start_idx : start_idx + batch_size]) - start_idx += batch_size + # indices = np.arange(self.buffer_size * self.n_envs) + # start_idx = 0 + # while start_idx < self.buffer_size * self.n_envs: + # yield self._get_samples(indices[start_idx : start_idx + batch_size]) + # start_idx += batch_size + + # Do not shuffle the sequence, only the env indices + n_minibatches = (self.buffer_size * self.n_envs) 
// batch_size + assert ( + self.n_envs % n_minibatches == 0 + ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" + + n_envs_per_batch = self.n_envs // n_minibatches + env_indices = np.random.permutation(self.n_envs) + + flat_indices = np.arange(self.buffer_size * self.n_envs).reshape(self.n_envs, self.buffer_size) + + + for start_env_idx in range(0, self.n_envs, n_envs_per_batch): + end_env_idx = start_env_idx + n_envs_per_batch + mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] + batch_inds = flat_indices[mini_batch_env_indices].ravel() + lstm_states = ( + self.hidden_states[:, :, mini_batch_env_indices, :][0], + self.cell_states[:, :, mini_batch_env_indices, :][0], + ) + + yield RecurrentRolloutBufferSamples( + observations=self.to_torch(self.observations[batch_inds]), + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + lstm_states=(self.to_torch(lstm_states[0]), self.to_torch(lstm_states[1])), + dones=self.to_torch(self.dones[batch_inds]), + ) def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RecurrentRolloutBufferSamples: return RecurrentRolloutBufferSamples( @@ -117,7 +149,7 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=(self.to_torch(self.hidden_states[batch_inds][0]), self.to_torch(self.cell_states[batch_inds][0])), + lstm_states=(self.to_torch(self.hidden_states[batch_inds][0:1]), self.to_torch(self.cell_states[batch_inds][0:1])), dones=self.to_torch(self.dones[batch_inds]), ) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 432ba759..d4e85d62 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -36,14 +36,14 @@ def step(self, action): def test_ppo_lstm(): from stable_baselines3.common.env_util import make_vec_env - env = make_vec_env("CartPole-v1", n_envs=8) + env = make_vec_env("CartPole-v1", n_envs=16) def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) return env - env = make_vec_env(make_env, n_envs=2) + env = make_vec_env(make_env, n_envs=16) # env = CartPoleNoVelEnv() # import gym # env = gym.make("CartPole-v1") @@ -52,10 +52,10 @@ def make_env(): model = RecurrentPPO( "MlpLstmPolicy", env, - n_steps=2048, + n_steps=128, learning_rate=3e-4, verbose=1, - batch_size=64, + batch_size=512, seed=0, # create_eval_env=True, ) From 97ec8ecd0da83b86c198e0ff3848eae605e9f571 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 23 Nov 2021 23:02:35 +0100 Subject: [PATCH 04/50] Add Box support --- sb3_contrib/ppo_lstm/ppo_lstm.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 42d6ebd9..8ea13adc 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -77,6 +77,8 @@ def __init__( ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, target_kl: Optional[float] = None, tensorboard_log: Optional[str] = None, create_eval_env: bool = 
False, @@ -96,9 +98,8 @@ def __init__( ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, - # TODO(antonin): add gSDE support - use_sde=False, - sde_sample_freq=-1, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, tensorboard_log=tensorboard_log, create_eval_env=create_eval_env, policy_kwargs=policy_kwargs, @@ -108,6 +109,7 @@ def __init__( device=device, _init_setup_model=False, supported_action_spaces=( + spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary, @@ -354,6 +356,10 @@ def train(self) -> None: # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + self.policy.set_lstm_states(rollout_data.lstm_states) self.policy.set_dones(rollout_data.dones) values, log_prob, entropy = self.policy.evaluate_actions( From a890976f17dd549b3a03bdab1bf9f073419bab26 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 23 Nov 2021 23:02:57 +0100 Subject: [PATCH 05/50] Fix sample order --- sb3_contrib/common/recurrent/buffers.py | 5 ++--- sb3_contrib/common/recurrent/torch_layers.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 3b47b020..a07e4619 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -115,12 +115,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf self.n_envs % n_minibatches == 0 ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" - n_envs_per_batch = self.n_envs // n_minibatches + # n_envs_per_batch = self.n_envs // n_minibatches + n_envs_per_batch = batch_size // self.buffer_size env_indices = np.random.permutation(self.n_envs) - flat_indices = np.arange(self.buffer_size * self.n_envs).reshape(self.n_envs, self.buffer_size) - for start_env_idx in range(0, self.n_envs, n_envs_per_batch): end_env_idx = start_env_idx + n_envs_per_batch mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] diff --git a/sb3_contrib/common/recurrent/torch_layers.py b/sb3_contrib/common/recurrent/torch_layers.py index 3a0f739f..68274560 100644 --- a/sb3_contrib/common/recurrent/torch_layers.py +++ b/sb3_contrib/common/recurrent/torch_layers.py @@ -5,6 +5,7 @@ import torch as th from stable_baselines3.common.preprocessing import get_flattened_obs_dim from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from stable_baselines3.common.utils import zip_strict from torch import nn @@ -41,22 +42,24 @@ def process_sequence( features = self.flatten(observations) # LSTM logic - # (sequence length, batch size, features dim) (batch size = n envs) - batch_size = lstm_states[0].shape[1] - features_sequence = features.reshape((-1, batch_size, self.lstm.input_size)) - dones = dones.reshape((-1, batch_size)) + # (sequence length, n_envs, features dim) (batch size = n envs) + n_envs = lstm_states[0].shape[1] + # Note: order matters and should be consistent with the one from the buffer + # above is when envs are interleaved + features_sequence = features.reshape((n_envs, -1, self.lstm.input_size)).swapaxes(0, 1) + dones = dones.reshape((n_envs, -1)).swapaxes(0, 1) lstm_output = [] # Iterate over the sequence - for features, done in zip(features_sequence, dones): + for features, done in zip_strict(features_sequence, dones): hidden, 
lstm_states = self.lstm( - features.unsqueeze(0), + features.unsqueeze(dim=0), ( (1.0 - done).view(1, -1, 1) * lstm_states[0], (1.0 - done).view(1, -1, 1) * lstm_states[1], ), ) lstm_output += [hidden] - lstm_output = th.flatten(th.cat(lstm_output), start_dim=0, end_dim=1) + lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) return lstm_output, lstm_states def set_lstm_states(self, lstm_states: Optional[Tuple[th.Tensor]] = None) -> None: @@ -76,4 +79,5 @@ def forward(self, observations: th.Tensor) -> th.Tensor: if self.dones is None: self.dones = th.zeros(len(observations)).float().to(observations.device) features, self._lstm_states = self.process_sequence(observations, self._lstm_states, self.dones) + self.dones = None return features From 7fecd9fdd7180164c47b0545b2aa122313b5ec7f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 11:57:41 +0100 Subject: [PATCH 06/50] Being cleanup, code is broken (again) --- sb3_contrib/common/recurrent/buffers.py | 77 +++++++------ sb3_contrib/common/recurrent/policies.py | 111 +++++++++++++++++-- sb3_contrib/common/recurrent/torch_layers.py | 50 ++++----- sb3_contrib/ppo_lstm/ppo_lstm.py | 54 ++++----- 4 files changed, 188 insertions(+), 104 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index a07e4619..295ef6dc 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -16,7 +16,7 @@ class RecurrentRolloutBufferSamples(NamedTuple): advantages: th.Tensor returns: th.Tensor lstm_states: Tuple[th.Tensor, th.Tensor] - dones: th.Tensor + episode_starts: th.Tensor class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): @@ -27,7 +27,7 @@ class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): advantages: th.Tensor returns: th.Tensor lstm_states: Tuple[th.Tensor, th.Tensor] - dones: th.Tensor + episode_starts: th.Tensor class RecurrentRolloutBuffer(RolloutBuffer): @@ -57,23 +57,24 @@ def __init__( ): self.lstm_states = lstm_states self.dones = None + self.initial_lstm_states = None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): - self.hidden_states = np.zeros_like(self.lstm_states[0]) - self.cell_states = np.zeros_like(self.lstm_states[1]) - self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) super().reset() - - def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], dones: np.ndarray, **kwargs) -> None: - """ - :param hidden_states: LSTM cell and hidden state - """ - self.hidden_states[self.pos] = np.array(lstm_states[0]) - self.cell_states[self.pos] = np.array(lstm_states[1]) - self.dones[self.pos] = np.array(dones) - - super().add(*args, **kwargs) + # self.hidden_states = np.zeros_like(self.lstm_states[0]) + # self.cell_states = np.zeros_like(self.lstm_states[1]) + # self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + + # def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], **kwargs) -> None: + # """ + # :param hidden_states: LSTM cell and hidden state + # """ + # self.hidden_states[self.pos] = np.array(lstm_states[0]) + # self.cell_states[self.pos] = np.array(lstm_states[1]) + # self.dones[self.pos] = np.array(dones) + # + # super().add(*args, **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "" @@ -94,7 +95,7 @@ def get(self, batch_size: Optional[int] = 
None) -> Generator[RecurrentRolloutBuf "returns", # "hidden_states", # "cell_states", - "dones", + "episode_starts", ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True @@ -110,23 +111,29 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf # start_idx += batch_size # Do not shuffle the sequence, only the env indices - n_minibatches = (self.buffer_size * self.n_envs) // batch_size - assert ( - self.n_envs % n_minibatches == 0 - ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" - - # n_envs_per_batch = self.n_envs // n_minibatches - n_envs_per_batch = batch_size // self.buffer_size - env_indices = np.random.permutation(self.n_envs) + # n_minibatches = (self.buffer_size * self.n_envs) // batch_size + n_minibatches = 1 + # assert ( + # self.n_envs % n_minibatches == 0 + # ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" + n_envs_per_batch = self.n_envs // n_minibatches + # n_envs_per_batch = batch_size // self.buffer_size + + # env_indices = np.random.permutation(self.n_envs) + env_indices = np.arange(self.n_envs) flat_indices = np.arange(self.buffer_size * self.n_envs).reshape(self.n_envs, self.buffer_size) for start_env_idx in range(0, self.n_envs, n_envs_per_batch): end_env_idx = start_env_idx + n_envs_per_batch mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] batch_inds = flat_indices[mini_batch_env_indices].ravel() + # lstm_states = ( + # self.hidden_states[:, :, mini_batch_env_indices, :][0], + # self.cell_states[:, :, mini_batch_env_indices, :][0], + # ) lstm_states = ( - self.hidden_states[:, :, mini_batch_env_indices, :][0], - self.cell_states[:, :, mini_batch_env_indices, :][0], + self.initial_lstm_states[0][:, mini_batch_env_indices].clone(), + self.initial_lstm_states[1][:, mini_batch_env_indices].clone(), ) yield RecurrentRolloutBufferSamples( @@ -136,22 +143,12 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=(self.to_torch(lstm_states[0]), self.to_torch(lstm_states[1])), - dones=self.to_torch(self.dones[batch_inds]), + # lstm_states=(self.to_torch(lstm_states[0]), self.to_torch(lstm_states[1])), + lstm_states=lstm_states, + # dones=self.to_torch(self.dones[batch_inds].flatten()), + episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RecurrentRolloutBufferSamples: - return RecurrentRolloutBufferSamples( - observations=self.to_torch(self.observations[batch_inds]), - actions=self.to_torch(self.actions[batch_inds]), - old_values=self.to_torch(self.values[batch_inds].flatten()), - old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), - advantages=self.to_torch(self.advantages[batch_inds].flatten()), - returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=(self.to_torch(self.hidden_states[batch_inds][0:1]), self.to_torch(self.cell_states[batch_inds][0:1])), - dones=self.to_torch(self.dones[batch_inds]), - ) - class RecurrentDictRolloutBuffer(DictRolloutBuffer): """ diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index fd57ee54..e19b2a93 100644 --- 
a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -3,7 +3,9 @@ import gym import numpy as np import torch as th +from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.preprocessing import preprocess_obs from stable_baselines3.common.torch_layers import BaseFeaturesExtractor from stable_baselines3.common.type_aliases import Schedule from torch import nn @@ -87,11 +89,108 @@ def __init__( optimizer_kwargs, ) + def extract_features( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Preprocess the observation if needed and extract features. + + :param obs: + :return: + """ + assert self.features_extractor is not None, "No features extractor was set" + preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) + return self.features_extractor(preprocessed_obs, lstm_states, episode_starts) + + def forward( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs, lstm_states, episode_starts) + latent_pi, latent_vf = self.mlp_extractor(features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def get_distribution( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> Distribution: + """ + Get the current policy distribution given the observations. + + :param obs: + :return: the action distribution. + """ + features = self.extract_features(obs, lstm_states, episode_starts) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: + :return: the estimated values. + """ + features = self.extract_features(obs, lstm_states, episode_starts) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + + def evaluate_actions( + self, + obs: th.Tensor, + actions: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. 
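Every policy entry point now receives the recurrent state and the episode-start flags explicitly instead of pulling them out of the features extractor. On the training side the intended call pattern is roughly the following sketch (field names follow the RecurrentRolloutBufferSamples defined earlier in this series):

for rollout_data in rollout_buffer.get(batch_size):
    values, log_prob, entropy = policy.evaluate_actions(
        rollout_data.observations,    # env-major, flattened over (n_envs, n_steps)
        rollout_data.actions,
        rollout_data.lstm_states,     # initial (h, c) of each sampled sequence
        rollout_data.episode_starts,  # resets the carried state inside a sequence
    )
    # values and log_prob then feed the usual PPO value and surrogate losses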
+ """ + # Preprocess the observation if needed + features = self.extract_features(obs, lstm_states, episode_starts) + latent_pi, latent_vf = self.mlp_extractor(features) + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + def _predict( self, observation: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, deterministic: bool = False, - lstm_states: Optional[Tuple[th.Tensor, th.Tensor]] = None, ) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -101,18 +200,12 @@ def _predict( :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy """ - self.features_extractor.set_lstm_states(lstm_states) - return self.get_distribution(observation).get_actions(deterministic=deterministic) + # self.features_extractor.set_lstm_states(lstm_states) + return self.get_distribution(observation, lstm_states, episode_starts).get_actions(deterministic=deterministic) def get_lstm_states(self) -> Tuple[th.Tensor, th.Tensor]: return self.features_extractor.lstm_states - def set_lstm_states(self, lstm_states: Optional[Tuple[th.Tensor]] = None) -> None: - self.features_extractor.set_lstm_states(lstm_states) - - def set_dones(self, dones: th.Tensor) -> None: - self.features_extractor.set_dones(dones) - def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], diff --git a/sb3_contrib/common/recurrent/torch_layers.py b/sb3_contrib/common/recurrent/torch_layers.py index 68274560..3bc19500 100644 --- a/sb3_contrib/common/recurrent/torch_layers.py +++ b/sb3_contrib/common/recurrent/torch_layers.py @@ -21,23 +21,14 @@ def __init__(self, observation_space: gym.Space, hidden_size: int = 64, num_laye super().__init__(observation_space, hidden_size) self.flatten = nn.Flatten() self.lstm = nn.LSTM(get_flattened_obs_dim(observation_space), hidden_size, num_layers=num_layers) - # One forward pass to initial hidden state - # dummy_cell_state, dummy_hidden = self.lstm() - # Cell and hidden state - n_envs = 1 - self.initial_hidden_state = (th.zeros(num_layers, n_envs, hidden_size), th.zeros(num_layers, n_envs, hidden_size)) - self._lstm_states = deepcopy(self.initial_hidden_state) - self.dones = None - - def reset_state(self) -> None: - self._lstm_states = deepcopy(self.initial_hidden_state) - self.dones = None + self._lstm_states = None + self.debug = False def process_sequence( self, observations: th.Tensor, lstm_states: Tuple[th.Tensor, th.Tensor], - dones: th.Tensor, + episode_starts: th.Tensor, ) -> Tuple[th.Tensor, th.Tensor]: features = self.flatten(observations) @@ -46,38 +37,37 @@ def process_sequence( n_envs = lstm_states[0].shape[1] # Note: order matters and should be consistent with the one from the buffer # above is when envs are interleaved + # Batch to sequence features_sequence = features.reshape((n_envs, -1, self.lstm.input_size)).swapaxes(0, 1) - dones = dones.reshape((n_envs, -1)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) + # if self.debug: + # import ipdb; ipdb.set_trace() lstm_output = [] # Iterate over the sequence - for features, done in zip_strict(features_sequence, dones): + for features, episode_start in zip_strict(features_sequence, episode_starts): hidden, lstm_states = self.lstm( features.unsqueeze(dim=0), ( - (1.0 - done).view(1, -1, 1) * lstm_states[0], - (1.0 - done).view(1, 
-1, 1) * lstm_states[1], + (1.0 - episode_start).view(1, -1, 1) * lstm_states[0], + (1.0 - episode_start).view(1, -1, 1) * lstm_states[1], ), ) + # if self.debug: + # import ipdb; ipdb.set_trace() lstm_output += [hidden] + # Sequence to batch lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) return lstm_output, lstm_states - def set_lstm_states(self, lstm_states: Optional[Tuple[th.Tensor]] = None) -> None: - if lstm_states is None: - self.reset_state() - else: - self._lstm_states = deepcopy(lstm_states) - - def set_dones(self, dones: th.Tensor) -> None: - self.dones = dones - @property def lstm_states(self) -> Tuple[th.Tensor, th.Tensor]: return self._lstm_states - def forward(self, observations: th.Tensor) -> th.Tensor: - if self.dones is None: - self.dones = th.zeros(len(observations)).float().to(observations.device) - features, self._lstm_states = self.process_sequence(observations, self._lstm_states, self.dones) - self.dones = None + def forward( + self, + observations: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + features, self._lstm_states = self.process_sequence(observations, lstm_states, episode_starts) return features diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 8ea13adc..460885c4 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -246,6 +246,8 @@ def collect_rollouts( callback.on_rollout_start() + rollout_buffer.initial_lstm_states = self.lstm_states[0].clone(), self.lstm_states[1].clone() + while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix @@ -254,9 +256,10 @@ def collect_rollouts( with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - self.policy.set_lstm_states(self.lstm_states) - actions, values, log_probs = self.policy.forward(obs_tensor) + episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) + actions, values, log_probs = self.policy.forward(obs_tensor, self.lstm_states, episode_starts) lstm_states = self.policy.get_lstm_states() + self.lstm_states = lstm_states[0].clone(), lstm_states[1].clone() actions = actions.cpu().numpy() # Rescale and perform action @@ -283,18 +286,21 @@ def collect_rollouts( # Handle timeout by bootstraping with value function # see GitHub issue #633 - for idx, done_ in enumerate(dones): - if ( - done_ - and infos[idx].get("terminal_observation") is not None - and infos[idx].get("TimeLimit.truncated", False) - ): - terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] - with th.no_grad(): - terminal_lstm_state = lstm_states[0][:, idx : idx + 1, :], lstm_states[1][:, idx : idx + 1, :] - self.policy.set_lstm_states(terminal_lstm_state) - terminal_value = self.policy.predict_values(terminal_obs)[0] - rewards[idx] += self.gamma * terminal_value + # for idx, done_ in enumerate(dones): + # if ( + # done_ + # and infos[idx].get("terminal_observation") is not None + # and infos[idx].get("TimeLimit.truncated", False) + # ): + # terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + # with th.no_grad(): + # terminal_lstm_state = ( + # self.lstm_states[0][:, idx : idx + 1, :], + # self.lstm_states[1][:, idx : idx + 1, :], + # ) + # episode_starts = th.tensor([False]).float().to(self.device) + # terminal_value = self.policy.predict_values(terminal_obs, 
terminal_lstm_state, episode_starts)[0] + # rewards[idx] += self.gamma * terminal_value rollout_buffer.add( self._last_obs, @@ -303,22 +309,18 @@ def collect_rollouts( self._last_episode_starts, values, log_probs, - lstm_states=(lstm_states[0].cpu().numpy(), lstm_states[1].cpu().numpy()), - dones=dones, + # lstm_states=(lstm_states[0].cpu().numpy(), lstm_states[1].cpu().numpy()), ) self._last_obs = new_obs self._last_episode_starts = dones - # Reset states if needed - for idx, done_ in enumerate(dones): - if done_: - lstm_states[0][:, idx, :] = 0.0 - lstm_states[1][:, idx, :] = 0.0 - self.lstm_states = lstm_states with th.no_grad(): # Compute value for the last timestep - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + # TODO: update the lstm states? + # TODO: check episode_starts + episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), self.lstm_states, episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -345,6 +347,7 @@ def train(self) -> None: clip_fractions = [] continue_training = True + self.policy.features_extractor.debug = True # train for n_epochs epochs for epoch in range(self.n_epochs): @@ -360,12 +363,13 @@ def train(self) -> None: if self.use_sde: self.policy.reset_noise(self.batch_size) - self.policy.set_lstm_states(rollout_data.lstm_states) - self.policy.set_dones(rollout_data.dones) values, log_prob, entropy = self.policy.evaluate_actions( rollout_data.observations, actions, + rollout_data.lstm_states, + rollout_data.episode_starts, ) + # self.policy.features_extractor.debug = False values = values.flatten() # Normalize advantage From 0ddc3f61707d7a428094a6e4f59c4bddf4d46eff Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 15:56:11 +0100 Subject: [PATCH 07/50] First working version (no shared lstm) --- sb3_contrib/common/recurrent/buffers.py | 3 +- sb3_contrib/common/recurrent/policies.py | 172 ++++++++++++++++--- sb3_contrib/common/recurrent/torch_layers.py | 20 +-- tests/test_lstm.py | 23 +-- 4 files changed, 166 insertions(+), 52 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 295ef6dc..9521bbdd 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -119,8 +119,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf n_envs_per_batch = self.n_envs // n_minibatches # n_envs_per_batch = batch_size // self.buffer_size - # env_indices = np.random.permutation(self.n_envs) - env_indices = np.arange(self.n_envs) + env_indices = np.random.permutation(self.n_envs) flat_indices = np.arange(self.buffer_size * self.n_envs).reshape(self.n_envs, self.buffer_size) for start_env_idx in range(0, self.n_envs, n_envs_per_batch): diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index e19b2a93..2c9420c8 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -6,12 +6,11 @@ from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.preprocessing import preprocess_obs -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor from stable_baselines3.common.type_aliases 
import Schedule +from stable_baselines3.common.utils import zip_strict from torch import nn -from sb3_contrib.common.recurrent.torch_layers import LSTMExtractor - # CombinedExtractor,; FlattenExtractor,; MlpExtractor,; NatureCNN,; create_mlp, @@ -63,7 +62,7 @@ def __init__( sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = LSTMExtractor, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, @@ -89,21 +88,48 @@ def __init__( optimizer_kwargs, ) - def extract_features( + num_layers = 1 + hidden_size = 64 + # hidden_size = self.features_dim + # TODO: adjust mlp extractor input shape + # and add lstm for value function + self.adapter = nn.Linear(hidden_size, self.features_dim) + self.lstm = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def _process_sequence( self, - obs: th.Tensor, + features: th.Tensor, lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, - ) -> th.Tensor: - """ - Preprocess the observation if needed and extract features. + ) -> Tuple[th.Tensor, th.Tensor]: + lstm_states = lstm_states[0].clone(), lstm_states[1].clone() + episode_starts = episode_starts.clone() + # lstm_states = lstm_states[0].clone() * 0.0, lstm_states[1].clone() * 0.0 + # LSTM logic + # (sequence length, n_envs, features dim) (batch size = n envs) + n_envs = lstm_states[0].shape[1] + # Note: order matters and should be consistent with the one from the buffer + # above is when envs are interleaved + # Batch to sequence + features_sequence = features.reshape((n_envs, -1, self.lstm.input_size)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) - :param obs: - :return: - """ - assert self.features_extractor is not None, "No features extractor was set" - preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) - return self.features_extractor(preprocessed_obs, lstm_states, episode_starts) + lstm_output = [] + # Iterate over the sequence + for features, episode_start in zip_strict(features_sequence, episode_starts): + hidden, lstm_states = self.lstm( + features.unsqueeze(dim=0), + ( + (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0], + (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1], + ), + ) + lstm_output += [self.adapter(hidden)] + # Sequence to batch + lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states def forward( self, @@ -120,14 +146,19 @@ def forward( :return: action, value and log probability of the action """ # Preprocess the observation if needed - features = self.extract_features(obs, lstm_states, episode_starts) - latent_pi, latent_vf = self.mlp_extractor(features) + features = self.extract_features(obs) + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(features) + # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = 
self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - return actions, values, log_prob + return actions, values, log_prob, lstm_states def get_distribution( self, @@ -141,8 +172,9 @@ def get_distribution( :param obs: :return: the action distribution. """ - features = self.extract_features(obs, lstm_states, episode_starts) - latent_pi = self.mlp_extractor.forward_actor(features) + features = self.extract_features(obs) + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) return self._get_action_dist_from_latent(latent_pi) def predict_values( @@ -157,7 +189,7 @@ def predict_values( :param obs: :return: the estimated values. """ - features = self.extract_features(obs, lstm_states, episode_starts) + features = self.extract_features(obs) latent_vf = self.mlp_extractor.forward_critic(features) return self.value_net(latent_vf) @@ -178,8 +210,12 @@ def evaluate_actions( and entropy of the action distribution. """ # Preprocess the observation if needed - features = self.extract_features(obs, lstm_states, episode_starts) - latent_pi, latent_vf = self.mlp_extractor(features) + features = self.extract_features(obs) + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(features) + distribution = self._get_action_dist_from_latent(latent_pi) log_prob = distribution.log_prob(actions) values = self.value_net(latent_vf) @@ -200,12 +236,8 @@ def _predict( :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy """ - # self.features_extractor.set_lstm_states(lstm_states) return self.get_distribution(observation, lstm_states, episode_starts).get_actions(deterministic=deterministic) - def get_lstm_states(self) -> Tuple[th.Tensor, th.Tensor]: - return self.features_extractor.lstm_states - def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], @@ -262,3 +294,89 @@ def predict( actions = actions[0] return actions, states + + +import torch +from torch.distributions.categorical import Categorical + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super(Agent, self).__init__() + self.network = nn.Flatten() + + self.critic = nn.Sequential( + nn.Linear(np.array(envs.observation_space.shape).prod(), 64), + nn.Tanh(), + nn.Linear(64, 1), + ) + + self.actor = nn.Sequential( + nn.Linear(64, 64), + nn.Tanh(), + nn.Linear(64, envs.action_space.n), + ) + self.lstm = nn.LSTM(np.array(envs.observation_space.shape).prod(), 64) + # self.lstm_critic = nn.LSTM(np.array(envs.observation_space.shape).prod(), 64) + + def get_states_critic(self, x, lstm_state, done): + hidden = self.network(x) + + # LSTM logic + batch_size = lstm_state[0].shape[1] + hidden = hidden.reshape((batch_size, -1, self.lstm_critic.input_size)).swapaxes(0, 1) + done = done.reshape((batch_size, -1)).swapaxes(0, 1) + new_hidden = [] + for h, d in zip(hidden, done): + h, lstm_state = self.lstm_critic( + h.unsqueeze(0), + ( + (1.0 - d).view(1, -1, 1) * lstm_state[0], + (1.0 - d).view(1, -1, 1) * lstm_state[1], + ), + ) + new_hidden += [h] + 
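The loops in this reference agent and the _process_sequence variants above all implement the same pattern: unroll the LSTM one timestep at a time and zero the carried state wherever an episode boundary is flagged. A stripped-down, self-contained version of that pattern (inputs assumed already shaped (seq_len, n_envs, ...)):

import torch as th
from torch import nn

def masked_unroll(lstm: nn.LSTM, seq: th.Tensor, starts: th.Tensor, state):
    # seq: (seq_len, n_envs, input_size); starts: (seq_len, n_envs), 1.0 where an episode begins
    outputs = []
    for x_t, start_t in zip(seq, starts):
        mask = (1.0 - start_t).view(1, -1, 1)
        # reset (h, c) only for the envs that just started a new episode
        out, state = lstm(x_t.unsqueeze(0), (mask * state[0], mask * state[1]))
        outputs.append(out)
    return th.cat(outputs), state

The concatenated outputs are then transposed and flattened back to the env-major batch layout, exactly as the surrounding code does with th.flatten(...transpose(0, 1), ...).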
new_hidden = torch.flatten(torch.cat(new_hidden).transpose(0, 1), 0, 1) + return new_hidden, lstm_state + + def get_states(self, x, lstm_state, done): + hidden = self.network(x) + + # LSTM logic + batch_size = lstm_state[0].shape[1] + hidden = hidden.reshape((batch_size, -1, self.lstm.input_size)).swapaxes(0, 1) + done = done.reshape((batch_size, -1)).swapaxes(0, 1) + new_hidden = [] + for h, d in zip(hidden, done): + h, lstm_state = self.lstm( + h.unsqueeze(0), + ( + (1.0 - d).view(1, -1, 1) * lstm_state[0], + (1.0 - d).view(1, -1, 1) * lstm_state[1], + ), + ) + new_hidden += [h] + new_hidden = torch.flatten(torch.cat(new_hidden).transpose(0, 1), 0, 1) + return new_hidden, lstm_state + + def get_value(self, x, lstm_state, done): + # hidden, _ = self.get_states_critic(x, (lstm_state[0] * 0.0, lstm_state[1] * 0.0), done) + return self.critic(x) + + def get_action_and_value(self, x, lstm_state, done, action=None): + hidden, lstm_state = self.get_states(x, lstm_state, done) + # hidden_critic, lstm_state_critic = self.get_states_critic(x, lstm_state, done) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(x), lstm_state + + def set_training_mode(self, x): + pass diff --git a/sb3_contrib/common/recurrent/torch_layers.py b/sb3_contrib/common/recurrent/torch_layers.py index 3bc19500..489aaf80 100644 --- a/sb3_contrib/common/recurrent/torch_layers.py +++ b/sb3_contrib/common/recurrent/torch_layers.py @@ -32,6 +32,9 @@ def process_sequence( ) -> Tuple[th.Tensor, th.Tensor]: features = self.flatten(observations) + lstm_states = lstm_states[0].clone(), lstm_states[1].clone() + episode_starts = episode_starts.clone() + # lstm_states = lstm_states[0].clone() * 0.0, lstm_states[1].clone() * 0.0 # LSTM logic # (sequence length, n_envs, features dim) (batch size = n envs) n_envs = lstm_states[0].shape[1] @@ -40,34 +43,27 @@ def process_sequence( # Batch to sequence features_sequence = features.reshape((n_envs, -1, self.lstm.input_size)).swapaxes(0, 1) episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) - # if self.debug: - # import ipdb; ipdb.set_trace() + lstm_output = [] # Iterate over the sequence for features, episode_start in zip_strict(features_sequence, episode_starts): hidden, lstm_states = self.lstm( features.unsqueeze(dim=0), ( - (1.0 - episode_start).view(1, -1, 1) * lstm_states[0], - (1.0 - episode_start).view(1, -1, 1) * lstm_states[1], + (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0], + (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1], ), ) - # if self.debug: - # import ipdb; ipdb.set_trace() lstm_output += [hidden] # Sequence to batch lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) return lstm_output, lstm_states - @property - def lstm_states(self) -> Tuple[th.Tensor, th.Tensor]: - return self._lstm_states - def forward( self, observations: th.Tensor, lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, ) -> th.Tensor: - features, self._lstm_states = self.process_sequence(observations, lstm_states, episode_starts) - return features + features, lstm_states = self.process_sequence(observations, lstm_states, episode_starts) + return features, lstm_states diff --git a/tests/test_lstm.py b/tests/test_lstm.py index d4e85d62..7cc9484a 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -36,29 +36,30 @@ def step(self, action): def test_ppo_lstm(): from 
stable_baselines3.common.env_util import make_vec_env - env = make_vec_env("CartPole-v1", n_envs=16) + # env = make_vec_env("CartPole-v1", n_envs=8) def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) return env - env = make_vec_env(make_env, n_envs=16) - # env = CartPoleNoVelEnv() - # import gym - # env = gym.make("CartPole-v1") - # env = TimeLimit(env, max_episode_steps=500) + env = make_vec_env(make_env, n_envs=8) model = RecurrentPPO( "MlpLstmPolicy", env, - n_steps=128, - learning_rate=3e-4, + n_steps=8, + learning_rate=2.5e-4, verbose=1, - batch_size=512, + batch_size=64, seed=0, + gae_lambda=0.95, + policy_kwargs=dict(ortho_init=False), + # ent_coef=0.01, + # policy_kwargs=dict(net_arch=[dict(pi=[64], vf=[64])]) # create_eval_env=True, ) + # model.learn(total_timesteps=500, eval_freq=250) - # model.learn(total_timesteps=100_000) - model.learn(total_timesteps=100) + model.learn(total_timesteps=1_000_000) + # model.learn(total_timesteps=100) From 5ef313b5f86b0948bfcd80209796d16a506dbe84 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 16:04:17 +0100 Subject: [PATCH 08/50] Start cleanup --- sb3_contrib/common/recurrent/buffers.py | 6 +- sb3_contrib/common/recurrent/policies.py | 91 +------------------- sb3_contrib/common/recurrent/torch_layers.py | 69 --------------- sb3_contrib/common/recurrent/utils.py | 0 sb3_contrib/ppo_lstm/ppo_lstm.py | 60 ++++++------- tests/test_lstm.py | 6 +- 6 files changed, 39 insertions(+), 193 deletions(-) delete mode 100644 sb3_contrib/common/recurrent/torch_layers.py delete mode 100644 sb3_contrib/common/recurrent/utils.py diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 9521bbdd..8c4f9160 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -49,14 +49,14 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - lstm_states: Tuple[np.ndarray, np.ndarray], + # lstm_states: Tuple[np.ndarray, np.ndarray], device: Union[th.device, str] = "cpu", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - self.lstm_states = lstm_states - self.dones = None + # self.lstm_states = lstm_states + # self.dones = None self.initial_lstm_states = None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 2c9420c8..c3a2d08f 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -104,9 +104,8 @@ def _process_sequence( lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, ) -> Tuple[th.Tensor, th.Tensor]: - lstm_states = lstm_states[0].clone(), lstm_states[1].clone() - episode_starts = episode_starts.clone() - # lstm_states = lstm_states[0].clone() * 0.0, lstm_states[1].clone() * 0.0 + # lstm_states = lstm_states[0].clone(), lstm_states[1].clone() + # episode_starts = episode_starts.clone() # LSTM logic # (sequence length, n_envs, features dim) (batch size = n envs) n_envs = lstm_states[0].shape[1] @@ -294,89 +293,3 @@ def predict( actions = actions[0] return actions, states - - -import torch -from torch.distributions.categorical import Categorical - - -def layer_init(layer, std=np.sqrt(2), bias_const=0.0): - torch.nn.init.orthogonal_(layer.weight, std) - torch.nn.init.constant_(layer.bias, bias_const) - return layer - - -class Agent(nn.Module): - def __init__(self, 
envs): - super(Agent, self).__init__() - self.network = nn.Flatten() - - self.critic = nn.Sequential( - nn.Linear(np.array(envs.observation_space.shape).prod(), 64), - nn.Tanh(), - nn.Linear(64, 1), - ) - - self.actor = nn.Sequential( - nn.Linear(64, 64), - nn.Tanh(), - nn.Linear(64, envs.action_space.n), - ) - self.lstm = nn.LSTM(np.array(envs.observation_space.shape).prod(), 64) - # self.lstm_critic = nn.LSTM(np.array(envs.observation_space.shape).prod(), 64) - - def get_states_critic(self, x, lstm_state, done): - hidden = self.network(x) - - # LSTM logic - batch_size = lstm_state[0].shape[1] - hidden = hidden.reshape((batch_size, -1, self.lstm_critic.input_size)).swapaxes(0, 1) - done = done.reshape((batch_size, -1)).swapaxes(0, 1) - new_hidden = [] - for h, d in zip(hidden, done): - h, lstm_state = self.lstm_critic( - h.unsqueeze(0), - ( - (1.0 - d).view(1, -1, 1) * lstm_state[0], - (1.0 - d).view(1, -1, 1) * lstm_state[1], - ), - ) - new_hidden += [h] - new_hidden = torch.flatten(torch.cat(new_hidden).transpose(0, 1), 0, 1) - return new_hidden, lstm_state - - def get_states(self, x, lstm_state, done): - hidden = self.network(x) - - # LSTM logic - batch_size = lstm_state[0].shape[1] - hidden = hidden.reshape((batch_size, -1, self.lstm.input_size)).swapaxes(0, 1) - done = done.reshape((batch_size, -1)).swapaxes(0, 1) - new_hidden = [] - for h, d in zip(hidden, done): - h, lstm_state = self.lstm( - h.unsqueeze(0), - ( - (1.0 - d).view(1, -1, 1) * lstm_state[0], - (1.0 - d).view(1, -1, 1) * lstm_state[1], - ), - ) - new_hidden += [h] - new_hidden = torch.flatten(torch.cat(new_hidden).transpose(0, 1), 0, 1) - return new_hidden, lstm_state - - def get_value(self, x, lstm_state, done): - # hidden, _ = self.get_states_critic(x, (lstm_state[0] * 0.0, lstm_state[1] * 0.0), done) - return self.critic(x) - - def get_action_and_value(self, x, lstm_state, done, action=None): - hidden, lstm_state = self.get_states(x, lstm_state, done) - # hidden_critic, lstm_state_critic = self.get_states_critic(x, lstm_state, done) - logits = self.actor(hidden) - probs = Categorical(logits=logits) - if action is None: - action = probs.sample() - return action, probs.log_prob(action), probs.entropy(), self.critic(x), lstm_state - - def set_training_mode(self, x): - pass diff --git a/sb3_contrib/common/recurrent/torch_layers.py b/sb3_contrib/common/recurrent/torch_layers.py deleted file mode 100644 index 489aaf80..00000000 --- a/sb3_contrib/common/recurrent/torch_layers.py +++ /dev/null @@ -1,69 +0,0 @@ -from copy import deepcopy -from typing import Optional, Tuple - -import gym -import torch as th -from stable_baselines3.common.preprocessing import get_flattened_obs_dim -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor -from stable_baselines3.common.utils import zip_strict -from torch import nn - - -class LSTMExtractor(BaseFeaturesExtractor): - """ - Feature extract that pass the data through an LSTM after flattening it. - Used as a placeholder when feature extraction is not needed. 
- - :param observation_space: - """ - - def __init__(self, observation_space: gym.Space, hidden_size: int = 64, num_layers: int = 1): - super().__init__(observation_space, hidden_size) - self.flatten = nn.Flatten() - self.lstm = nn.LSTM(get_flattened_obs_dim(observation_space), hidden_size, num_layers=num_layers) - self._lstm_states = None - self.debug = False - - def process_sequence( - self, - observations: th.Tensor, - lstm_states: Tuple[th.Tensor, th.Tensor], - episode_starts: th.Tensor, - ) -> Tuple[th.Tensor, th.Tensor]: - features = self.flatten(observations) - - lstm_states = lstm_states[0].clone(), lstm_states[1].clone() - episode_starts = episode_starts.clone() - # lstm_states = lstm_states[0].clone() * 0.0, lstm_states[1].clone() * 0.0 - # LSTM logic - # (sequence length, n_envs, features dim) (batch size = n envs) - n_envs = lstm_states[0].shape[1] - # Note: order matters and should be consistent with the one from the buffer - # above is when envs are interleaved - # Batch to sequence - features_sequence = features.reshape((n_envs, -1, self.lstm.input_size)).swapaxes(0, 1) - episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) - - lstm_output = [] - # Iterate over the sequence - for features, episode_start in zip_strict(features_sequence, episode_starts): - hidden, lstm_states = self.lstm( - features.unsqueeze(dim=0), - ( - (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0], - (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1], - ), - ) - lstm_output += [hidden] - # Sequence to batch - lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) - return lstm_output, lstm_states - - def forward( - self, - observations: th.Tensor, - lstm_states: Tuple[th.Tensor, th.Tensor], - episode_starts: th.Tensor, - ) -> th.Tensor: - features, lstm_states = self.process_sequence(observations, lstm_states, episode_starts) - return features, lstm_states diff --git a/sb3_contrib/common/recurrent/utils.py b/sb3_contrib/common/recurrent/utils.py deleted file mode 100644 index e69de29b..00000000 diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 460885c4..fc885ed8 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -1,4 +1,5 @@ import time +from copy import deepcopy from typing import Any, Dict, Optional, Tuple, Type, Union import gym @@ -142,12 +143,14 @@ def _setup_model(self) -> None: ) self.policy = self.policy.to(self.device) + lstm = self.policy.lstm + if not isinstance(self.policy, RecurrentActorCriticPolicy): raise ValueError("Policy must subclass RecurrentActorCriticPolicy") - lstm = self.policy.features_extractor.lstm - hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) + single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) self.lstm_states = ( th.zeros(single_hidden_state_shape).to(self.device), @@ -158,7 +161,7 @@ def _setup_model(self) -> None: self.n_steps, self.observation_space, self.action_space, - lstm_states, + # lstm_states, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, @@ -246,7 +249,8 @@ def collect_rollouts( callback.on_rollout_start() - 
rollout_buffer.initial_lstm_states = self.lstm_states[0].clone(), self.lstm_states[1].clone() + rollout_buffer.initial_lstm_states = deepcopy(self.lstm_states) + lstm_states = deepcopy(self.lstm_states) while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: @@ -257,9 +261,8 @@ def collect_rollouts( # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) - actions, values, log_probs = self.policy.forward(obs_tensor, self.lstm_states, episode_starts) - lstm_states = self.policy.get_lstm_states() - self.lstm_states = lstm_states[0].clone(), lstm_states[1].clone() + actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts) + actions = actions.cpu().numpy() # Rescale and perform action @@ -286,21 +289,21 @@ def collect_rollouts( # Handle timeout by bootstraping with value function # see GitHub issue #633 - # for idx, done_ in enumerate(dones): - # if ( - # done_ - # and infos[idx].get("terminal_observation") is not None - # and infos[idx].get("TimeLimit.truncated", False) - # ): - # terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] - # with th.no_grad(): - # terminal_lstm_state = ( - # self.lstm_states[0][:, idx : idx + 1, :], - # self.lstm_states[1][:, idx : idx + 1, :], - # ) - # episode_starts = th.tensor([False]).float().to(self.device) - # terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] - # rewards[idx] += self.gamma * terminal_value + for idx, done_ in enumerate(dones): + if ( + done_ + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_lstm_state = ( + lstm_states[0][:, idx : idx + 1, :], + lstm_states[1][:, idx : idx + 1, :], + ) + episode_starts = th.tensor([False]).float().to(self.device) + terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] + rewards[idx] += self.gamma * terminal_value rollout_buffer.add( self._last_obs, @@ -315,12 +318,12 @@ def collect_rollouts( self._last_obs = new_obs self._last_episode_starts = dones + self.lstm_states = deepcopy(lstm_states) + with th.no_grad(): # Compute value for the last timestep - # TODO: update the lstm states? 
- # TODO: check episode_starts - episode_starts = th.tensor(self._last_episode_starts).float().to(self.device) - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), self.lstm_states, episode_starts) + episode_starts = th.tensor(dones).float().to(self.device) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states, episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) @@ -347,7 +350,7 @@ def train(self) -> None: clip_fractions = [] continue_training = True - self.policy.features_extractor.debug = True + # self.policy.features_extractor.debug = True # train for n_epochs epochs for epoch in range(self.n_epochs): @@ -369,7 +372,6 @@ def train(self) -> None: rollout_data.lstm_states, rollout_data.episode_starts, ) - # self.policy.features_extractor.debug = False values = values.flatten() # Normalize advantage diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 7cc9484a..9b9c3514 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -36,14 +36,14 @@ def step(self, action): def test_ppo_lstm(): from stable_baselines3.common.env_util import make_vec_env - # env = make_vec_env("CartPole-v1", n_envs=8) + env = make_vec_env("CartPole-v1", n_envs=8) def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) return env - env = make_vec_env(make_env, n_envs=8) + # env = make_vec_env(make_env, n_envs=8) model = RecurrentPPO( "MlpLstmPolicy", @@ -61,5 +61,5 @@ def make_env(): ) # model.learn(total_timesteps=500, eval_freq=250) - model.learn(total_timesteps=1_000_000) + model.learn(total_timesteps=22_000) # model.learn(total_timesteps=100) From c803ac9923b5b1096c3192eb60dbc1b4baa1edd1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 16:50:39 +0100 Subject: [PATCH 09/50] Try rnn with value function --- sb3_contrib/common/recurrent/buffers.py | 13 +++-- sb3_contrib/common/recurrent/policies.py | 61 +++++++++++++++--------- sb3_contrib/ppo_lstm/ppo_lstm.py | 19 +++++--- 3 files changed, 61 insertions(+), 32 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 8c4f9160..723ba58e 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -130,10 +130,15 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf # self.hidden_states[:, :, mini_batch_env_indices, :][0], # self.cell_states[:, :, mini_batch_env_indices, :][0], # ) - lstm_states = ( - self.initial_lstm_states[0][:, mini_batch_env_indices].clone(), - self.initial_lstm_states[1][:, mini_batch_env_indices].clone(), + lstm_states_pi = ( + self.initial_lstm_states[0][0][:, mini_batch_env_indices].clone(), + self.initial_lstm_states[0][1][:, mini_batch_env_indices].clone(), ) + # lstm_states_vf = ( + # self.initial_lstm_states[1][0][:, mini_batch_env_indices].clone(), + # self.initial_lstm_states[1][1][:, mini_batch_env_indices].clone(), + # ) + lstm_states_vf = None yield RecurrentRolloutBufferSamples( observations=self.to_torch(self.observations[batch_inds]), @@ -143,7 +148,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), # lstm_states=(self.to_torch(lstm_states[0]), self.to_torch(lstm_states[1])), - lstm_states=lstm_states, + lstm_states=(lstm_states_pi, lstm_states_vf), # 
dones=self.to_torch(self.dones[batch_inds].flatten()), episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index c3a2d08f..ace1de62 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -6,7 +6,7 @@ from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.preprocessing import preprocess_obs -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import zip_strict from torch import nn @@ -68,6 +68,8 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): + hidden_size = 64 + self.lstm_output_dim = hidden_size super().__init__( observation_space, action_space, @@ -89,43 +91,49 @@ def __init__( ) num_layers = 1 - hidden_size = 64 - # hidden_size = self.features_dim - # TODO: adjust mlp extractor input shape - # and add lstm for value function - self.adapter = nn.Linear(hidden_size, self.features_dim) - self.lstm = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) + self.lstm_actor = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) + # self.lstm_critic = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) + self.critic = nn.Linear(self.features_dim, hidden_size) # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. 
+ """ + self.mlp_extractor = MlpExtractor( + self.lstm_output_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + @staticmethod def _process_sequence( - self, features: th.Tensor, lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, + lstm: nn.LSTM, ) -> Tuple[th.Tensor, th.Tensor]: - # lstm_states = lstm_states[0].clone(), lstm_states[1].clone() - # episode_starts = episode_starts.clone() # LSTM logic # (sequence length, n_envs, features dim) (batch size = n envs) n_envs = lstm_states[0].shape[1] - # Note: order matters and should be consistent with the one from the buffer - # above is when envs are interleaved # Batch to sequence - features_sequence = features.reshape((n_envs, -1, self.lstm.input_size)).swapaxes(0, 1) + features_sequence = features.reshape((n_envs, -1, lstm.input_size)).swapaxes(0, 1) episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) lstm_output = [] # Iterate over the sequence for features, episode_start in zip_strict(features_sequence, episode_starts): - hidden, lstm_states = self.lstm( + hidden, lstm_states = lstm( features.unsqueeze(dim=0), ( (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0], (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1], ), ) - lstm_output += [self.adapter(hidden)] + lstm_output += [hidden] # Sequence to batch lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) return lstm_output, lstm_states @@ -147,17 +155,20 @@ def forward( # Preprocess the observation if needed features = self.extract_features(obs) # latent_pi, latent_vf = self.mlp_extractor(features) - latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts) + latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states[0], episode_starts, self.lstm_actor) + # latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states[1], episode_starts, self.lstm_critic) + lstm_states_vf = None + latent_vf = self.critic(features) latent_pi = self.mlp_extractor.forward_actor(latent_pi) - latent_vf = self.mlp_extractor.forward_critic(features) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - return actions, values, log_prob, lstm_states + return actions, values, log_prob, (lstm_states_pi, lstm_states_vf) def get_distribution( self, @@ -172,7 +183,7 @@ def get_distribution( :return: the action distribution. """ features = self.extract_features(obs) - latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts) + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) latent_pi = self.mlp_extractor.forward_actor(latent_pi) return self._get_action_dist_from_latent(latent_pi) @@ -189,7 +200,9 @@ def predict_values( :return: the estimated values. 
""" features = self.extract_features(obs) - latent_vf = self.mlp_extractor.forward_critic(features) + # latent_vf, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) + latent_vf = self.critic(features) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) return self.value_net(latent_vf) def evaluate_actions( @@ -211,9 +224,13 @@ def evaluate_actions( # Preprocess the observation if needed features = self.extract_features(obs) # latent_pi, latent_vf = self.mlp_extractor(features) - latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts) + latent_pi, _ = self._process_sequence(features, lstm_states[0], episode_starts, self.lstm_actor) + # latent_vf, _ = self._process_sequence(features, lstm_states[1], episode_starts, self.lstm_critic) + latent_vf = self.critic(features) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) - latent_vf = self.mlp_extractor.forward_critic(features) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) log_prob = distribution.log_prob(actions) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index fc885ed8..2a3e8e0d 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -143,7 +143,7 @@ def _setup_model(self) -> None: ) self.policy = self.policy.to(self.device) - lstm = self.policy.lstm + lstm = self.policy.lstm_actor if not isinstance(self.policy, RecurrentActorCriticPolicy): raise ValueError("Policy must subclass RecurrentActorCriticPolicy") @@ -152,9 +152,16 @@ def _setup_model(self) -> None: # lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) + # hidden states for actor and critic self.lstm_states = ( - th.zeros(single_hidden_state_shape).to(self.device), - th.zeros(single_hidden_state_shape).to(self.device), + ( + th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape).to(self.device), + ), + ( + th.zeros(single_hidden_state_shape).to(self.device), + th.zeros(single_hidden_state_shape).to(self.device), + ), ) self.rollout_buffer = buffer_cls( @@ -298,8 +305,8 @@ def collect_rollouts( terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): terminal_lstm_state = ( - lstm_states[0][:, idx : idx + 1, :], - lstm_states[1][:, idx : idx + 1, :], + lstm_states[1][0][:, idx : idx + 1, :], + lstm_states[1][1][:, idx : idx + 1, :], ) episode_starts = th.tensor([False]).float().to(self.device) terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] @@ -323,7 +330,7 @@ def collect_rollouts( with th.no_grad(): # Compute value for the last timestep episode_starts = th.tensor(dones).float().to(self.device) - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states, episode_starts) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states[1], episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) From 0c8ab150add2990c4fa0d1fc09e91bf6b3a2ac4f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 17:02:50 +0100 Subject: [PATCH 10/50] Re-enable batch size --- sb3_contrib/common/recurrent/buffers.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py 
b/sb3_contrib/common/recurrent/buffers.py index 723ba58e..882108eb 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -104,18 +104,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - # indices = np.arange(self.buffer_size * self.n_envs) - # start_idx = 0 - # while start_idx < self.buffer_size * self.n_envs: - # yield self._get_samples(indices[start_idx : start_idx + batch_size]) - # start_idx += batch_size - # Do not shuffle the sequence, only the env indices - # n_minibatches = (self.buffer_size * self.n_envs) // batch_size - n_minibatches = 1 - # assert ( - # self.n_envs % n_minibatches == 0 - # ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" + n_minibatches = (self.buffer_size * self.n_envs) // batch_size + # n_minibatches = 1 + + assert ( + self.n_envs % n_minibatches == 0 + ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" n_envs_per_batch = self.n_envs // n_minibatches # n_envs_per_batch = batch_size // self.buffer_size From eb1e6c1bcf2fd6afb53ce7d8f0726ef9fd64c28e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 17:06:17 +0100 Subject: [PATCH 11/50] Deactivate vf rnn --- sb3_contrib/ppo_lstm/ppo_lstm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 2a3e8e0d..41172e78 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -304,10 +304,11 @@ def collect_rollouts( ): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): - terminal_lstm_state = ( - lstm_states[1][0][:, idx : idx + 1, :], - lstm_states[1][1][:, idx : idx + 1, :], - ) + # terminal_lstm_state = ( + # lstm_states[1][0][:, idx : idx + 1, :], + # lstm_states[1][1][:, idx : idx + 1, :], + # ) + terminal_lstm_state = None episode_starts = th.tensor([False]).float().to(self.device) terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] rewards[idx] += self.gamma * terminal_value From f0133468e70e5672f7416c63ac2c75b148b28190 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 25 Nov 2021 20:42:05 +0100 Subject: [PATCH 12/50] Allow any batch size --- sb3_contrib/common/recurrent/buffers.py | 94 ++++++++++++++++++------- sb3_contrib/ppo_lstm/ppo_lstm.py | 11 +-- tests/test_lstm.py | 21 +++--- 3 files changed, 87 insertions(+), 39 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 882108eb..206f803c 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -49,32 +49,30 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - # lstm_states: Tuple[np.ndarray, np.ndarray], + lstm_states: Tuple[np.ndarray, np.ndarray], device: Union[th.device, str] = "cpu", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - # self.lstm_states = lstm_states + self.lstm_states = lstm_states # self.dones = None self.initial_lstm_states = None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): super().reset() - # self.hidden_states = np.zeros_like(self.lstm_states[0]) - # self.cell_states = np.zeros_like(self.lstm_states[1]) - 
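The sampling scheme toggled across these patches requires a minibatch to cover a whole number of environments, so every sampled slice stays one contiguous sequence and the stored initial state can simply be sliced along the env axis. Schematically (names mirror the buffer code; h0 and c0 stand for the rollout's initial hidden and cell states):

n_minibatches = (buffer_size * n_envs) // batch_size
assert n_envs % n_minibatches == 0, "batch_size must map to a whole number of envs"
n_envs_per_batch = n_envs // n_minibatches

env_indices = np.random.permutation(n_envs)
flat_indices = np.arange(buffer_size * n_envs).reshape(n_envs, buffer_size)
for start in range(0, n_envs, n_envs_per_batch):
    batch_envs = env_indices[start : start + n_envs_per_batch]
    batch_inds = flat_indices[batch_envs].ravel()                      # contiguous steps per env
    init_state = (h0[:, batch_envs].clone(), c0[:, batch_envs].clone())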
# self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - - # def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], **kwargs) -> None: - # """ - # :param hidden_states: LSTM cell and hidden state - # """ - # self.hidden_states[self.pos] = np.array(lstm_states[0]) - # self.cell_states[self.pos] = np.array(lstm_states[1]) - # self.dones[self.pos] = np.array(dones) - # - # super().add(*args, **kwargs) + self.hidden_states = np.zeros_like(self.lstm_states[0]) + self.cell_states = np.zeros_like(self.lstm_states[1]) + + def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.hidden_states[self.pos] = np.array(lstm_states[0]) + self.cell_states[self.pos] = np.array(lstm_states[1]) + + super().add(*args, **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: assert self.full, "" @@ -83,8 +81,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if not self.generator_ready: # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - # self.hidden_states = self.hidden_states.swapaxes(1, 2) - # self.cell_states = self.cell_states.swapaxes(1, 2) + self.hidden_states = self.hidden_states.swapaxes(1, 2) + self.cell_states = self.cell_states.swapaxes(1, 2) for tensor in [ "observations", @@ -93,8 +91,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf "log_probs", "advantages", "returns", - # "hidden_states", - # "cell_states", + "hidden_states", + "cell_states", "episode_starts", ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) @@ -104,9 +102,58 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - # Do not shuffle the sequence, only the env indices + any_batch_size = True + if any_batch_size: + # No shuffling implemented yet + indices = np.arange(self.buffer_size * self.n_envs) + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + # Create sequence if env change too + seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]) + starts = np.where(seq_start == True)[0] + ends = np.concatenate([(starts - 1)[1:], np.array([len(batch_inds)])]) + def pad(tensor: np.ndarray): + seq = [self.to_torch(tensor[start:end + 1]) for start, end in zip(starts, ends)] + return th.nn.utils.rnn.pad_sequence(seq) + + n_layers = self.hidden_states.shape[1] + n_seq = len(starts) + max_length = pad(self.actions[batch_inds]).shape[0] + # TODO: output mask to not backpropagate everywhere + padded_batch_size = n_seq * max_length + lstm_states_pi = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), + self.cell_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), + ) + lstm_states_pi = ( + self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]) + ) + + lstm_states_vf = None + + yield RecurrentRolloutBufferSamples( + observations=pad(self.observations[batch_inds]).swapaxes(0, 
1).reshape((padded_batch_size,) + self.obs_shape), + actions=pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=pad(self.values[batch_inds]).swapaxes(0, 1).flatten(), + old_log_prob=pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), + advantages=pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), + returns=pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), + lstm_states=(lstm_states_pi, lstm_states_vf), + episode_starts=pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), + ) + start_idx += batch_size + + if any_batch_size: + return + + # Baselines way of sampling, constraint in the batch size + # and number of environments n_minibatches = (self.buffer_size * self.n_envs) // batch_size - # n_minibatches = 1 assert ( self.n_envs % n_minibatches == 0 @@ -114,6 +161,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf n_envs_per_batch = self.n_envs // n_minibatches # n_envs_per_batch = batch_size // self.buffer_size + # Do not shuffle the sequence, only the env indices env_indices = np.random.permutation(self.n_envs) flat_indices = np.arange(self.buffer_size * self.n_envs).reshape(self.n_envs, self.buffer_size) @@ -121,7 +169,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf end_env_idx = start_env_idx + n_envs_per_batch mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] batch_inds = flat_indices[mini_batch_env_indices].ravel() - # lstm_states = ( + # lstm_states_pi = ( # self.hidden_states[:, :, mini_batch_env_indices, :][0], # self.cell_states[:, :, mini_batch_env_indices, :][0], # ) @@ -142,9 +190,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - # lstm_states=(self.to_torch(lstm_states[0]), self.to_torch(lstm_states[1])), lstm_states=(lstm_states_pi, lstm_states_vf), - # dones=self.to_torch(self.dones[batch_inds].flatten()), episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 41172e78..a943c70b 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -148,8 +148,8 @@ def _setup_model(self) -> None: if not isinstance(self.policy, RecurrentActorCriticPolicy): raise ValueError("Policy must subclass RecurrentActorCriticPolicy") - # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - # lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) + hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) # hidden states for actor and critic @@ -168,7 +168,7 @@ def _setup_model(self) -> None: self.n_steps, self.observation_space, self.action_space, - # lstm_states, + lstm_states, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, @@ -320,13 +320,14 @@ def collect_rollouts( self._last_episode_starts, values, log_probs, - # lstm_states=(lstm_states[0].cpu().numpy(), lstm_states[1].cpu().numpy()), + lstm_states=(self.lstm_states[0][0].cpu().numpy(), 
self.lstm_states[0][1].cpu().numpy()), ) self._last_obs = new_obs self._last_episode_starts = dones + self.lstm_states = lstm_states - self.lstm_states = deepcopy(lstm_states) + # self.lstm_states = deepcopy(lstm_states) with th.no_grad(): # Compute value for the last timestep diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 9b9c3514..801c924b 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -36,30 +36,31 @@ def step(self, action): def test_ppo_lstm(): from stable_baselines3.common.env_util import make_vec_env - env = make_vec_env("CartPole-v1", n_envs=8) + # env = make_vec_env("CartPole-v1", n_envs=8) def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) return env - # env = make_vec_env(make_env, n_envs=8) + env = make_vec_env(make_env, n_envs=8) model = RecurrentPPO( "MlpLstmPolicy", env, - n_steps=8, - learning_rate=2.5e-4, + n_steps=32, + learning_rate=0.0007, verbose=1, - batch_size=64, + batch_size=256, seed=0, - gae_lambda=0.95, - policy_kwargs=dict(ortho_init=False), - # ent_coef=0.01, + n_epochs=10, + # max_grad_norm=1, + gae_lambda=0.98, + policy_kwargs=dict(net_arch=[dict(vf=[64])], ortho_init=False), # policy_kwargs=dict(net_arch=[dict(pi=[64], vf=[64])]) # create_eval_env=True, ) # model.learn(total_timesteps=500, eval_freq=250) - model.learn(total_timesteps=22_000) - # model.learn(total_timesteps=100) + # model.learn(total_timesteps=1_000_000) + model.learn(total_timesteps=100) From a14f2cea98112104cdc648db2db6582bd485dacf Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 26 Nov 2021 11:18:46 +0100 Subject: [PATCH 13/50] Add support for evaluation --- sb3_contrib/common/recurrent/buffers.py | 95 ++++++++++++++---------- sb3_contrib/common/recurrent/policies.py | 53 ++++++------- sb3_contrib/ppo_lstm/ppo_lstm.py | 2 - tests/test_lstm.py | 25 +++++-- 4 files changed, 99 insertions(+), 76 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 206f803c..1b5ae891 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -54,10 +54,13 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, + sampling_style: str = "default", # "defaults" or "per_env" ): self.lstm_states = lstm_states # self.dones = None self.initial_lstm_states = None + self.sampling_style = sampling_style + self.starts, self.ends = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): @@ -102,53 +105,24 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - any_batch_size = True - if any_batch_size: - # No shuffling implemented yet + if self.sampling_style == "default": + # No shuffling + # indices = np.arange(self.buffer_size * self.n_envs) + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) env_change[0, :] = 1.0 env_change = self.swap_and_flatten(env_change) + start_idx = 0 while start_idx < self.buffer_size * self.n_envs: batch_inds = indices[start_idx : start_idx + batch_size] - # Create sequence if env change too - seq_start = 
np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]) - starts = np.where(seq_start == True)[0] - ends = np.concatenate([(starts - 1)[1:], np.array([len(batch_inds)])]) - def pad(tensor: np.ndarray): - seq = [self.to_torch(tensor[start:end + 1]) for start, end in zip(starts, ends)] - return th.nn.utils.rnn.pad_sequence(seq) - - n_layers = self.hidden_states.shape[1] - n_seq = len(starts) - max_length = pad(self.actions[batch_inds]).shape[0] - # TODO: output mask to not backpropagate everywhere - padded_batch_size = n_seq * max_length - lstm_states_pi = ( - # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), - self.cell_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), - ) - lstm_states_pi = ( - self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]) - ) - - lstm_states_vf = None - - yield RecurrentRolloutBufferSamples( - observations=pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), - actions=pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=pad(self.values[batch_inds]).swapaxes(0, 1).flatten(), - old_log_prob=pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), - advantages=pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), - returns=pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), - lstm_states=(lstm_states_pi, lstm_states_vf), - episode_starts=pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), - ) + yield self._get_samples(batch_inds, env_change) start_idx += batch_size - - if any_batch_size: return # Baselines way of sampling, constraint in the batch size @@ -159,7 +133,6 @@ def pad(tensor: np.ndarray): self.n_envs % n_minibatches == 0 ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" n_envs_per_batch = self.n_envs // n_minibatches - # n_envs_per_batch = batch_size // self.buffer_size # Do not shuffle the sequence, only the env indices env_indices = np.random.permutation(self.n_envs) @@ -194,6 +167,46 @@ def pad(tensor: np.ndarray): episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) + def pad(self, tensor: np.ndarray) -> th.Tensor: + seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] + return th.nn.utils.rnn.pad_sequence(seq) + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentRolloutBufferSamples: + # Create sequence if env change too + seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]) + self.starts = np.where(seq_start == True)[0] # noqa: E712 + self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) + + n_layers = self.hidden_states.shape[1] + n_seq = len(self.starts) + max_length = self.pad(self.actions[batch_inds]).shape[0] + # TODO: output mask to not backpropagate everywhere + padded_batch_size = n_seq * max_length + lstm_states_pi = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.cell_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) + + lstm_states_vf = None + + return RecurrentRolloutBufferSamples( + 
observations=self.pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), + actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad(self.values[batch_inds]).swapaxes(0, 1).flatten(), + old_log_prob=self.pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), + advantages=self.pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), + returns=self.pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), + lstm_states=(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), + ) + class RecurrentDictRolloutBuffer(DictRolloutBuffer): """ @@ -302,5 +315,5 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), lstm_states=(self.to_torch(self.hidden_states[batch_inds]), self.to_torch(self.cell_states[batch_inds])), - dones=self.to_torch(self.dones[batch_inds]), + episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index ace1de62..91e386dc 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -5,7 +5,6 @@ import torch as th from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.preprocessing import preprocess_obs from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import zip_strict @@ -92,6 +91,7 @@ def __init__( num_layers = 1 self.lstm_actor = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) + self.lstm_shape = (num_layers, 1, hidden_size) # self.lstm_critic = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) self.critic = nn.Linear(self.features_dim, hidden_size) # Setup optimizer with initial learning rate @@ -144,7 +144,7 @@ def forward( lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Tuple[th.Tensor, ...]]: """ Forward pass in all the networks (actor and critic) @@ -156,6 +156,7 @@ def forward( features = self.extract_features(obs) # latent_pi, latent_vf = self.mlp_extractor(features) latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states[0], episode_starts, self.lstm_actor) + # TODO: try re-using LSTM features for value function but using detach # latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states[1], episode_starts, self.lstm_critic) lstm_states_vf = None latent_vf = self.critic(features) @@ -175,17 +176,17 @@ def get_distribution( obs: th.Tensor, lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, - ) -> Distribution: + ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]: """ Get the current policy distribution given the observations. :param obs: - :return: the action distribution. + :return: the action distribution and new hidden states. 
""" features = self.extract_features(obs) - latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) latent_pi = self.mlp_extractor.forward_actor(latent_pi) - return self._get_action_dist_from_latent(latent_pi) + return self._get_action_dist_from_latent(latent_pi), lstm_states def predict_values( self, @@ -228,7 +229,6 @@ def evaluate_actions( # latent_vf, _ = self._process_sequence(features, lstm_states[1], episode_starts, self.lstm_critic) latent_vf = self.critic(features) - latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -243,16 +243,17 @@ def _predict( lstm_states: Tuple[th.Tensor, th.Tensor], episode_starts: th.Tensor, deterministic: bool = False, - ) -> th.Tensor: + ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]: """ Get the action according to the policy for a given observation. :param observation: :param lstm_states: :param deterministic: Whether to use stochastic or deterministic actions - :return: Taken action according to the policy + :return: Taken action according to the policy and hidden states of the RNN """ - return self.get_distribution(observation, lstm_states, episode_starts).get_actions(deterministic=deterministic) + distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts) + return distribution.get_actions(deterministic=deterministic), lstm_states def predict( self, @@ -260,38 +261,40 @@ def predict( state: Optional[np.ndarray] = None, mask: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the policy action and state from an observation (and optional state). Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) + :param state: The last hidden states (can be None, used in recurrent policies) :param mask: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. 
- :return: the model's action and the next state + :return: the model's action and the next hidden state (used in recurrent policies) """ - # TODO (GH/1): add support for RNN policies - # if state is None: - # state = self.features_extractor.initial_lstm_states - # if mask is None: - # mask = [False for _ in range(self.n_envs)] # Switch to eval mode (this affects batch norm / dropout) self.set_training_mode(False) observation, vectorized_env = self.obs_to_tensor(observation) - # TODO(antonin): preprocess state - # if state is not None: - # lstm_states = None + n_envs = observation.shape[0] + # state : (n_layers, n_envs, dim) + if state is None: + state = np.concatenate([np.zeros(self.lstm_shape) for _ in range(n_envs)], axis=1) + state = (state, state) + + if mask is None: + mask = np.array([False for _ in range(n_envs)]) with th.no_grad(): - actions = self._predict(observation, lstm_states=state, deterministic=deterministic) - states = self.get_lstm_states() + # Convert to PyTorch tensors + states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) + mask = th.tensor(mask).float().to(self.device) + actions, states = self._predict(observation, lstm_states=states, episode_starts=mask, deterministic=deterministic) states = (states[0].cpu().numpy(), states[1].cpu().numpy()) - # TODO(antonin): fix eval script - states = None # Convert to numpy actions = actions.cpu().numpy() diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index a943c70b..cdb087d5 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -327,8 +327,6 @@ def collect_rollouts( self._last_episode_starts = dones self.lstm_states = lstm_states - # self.lstm_states = deepcopy(lstm_states) - with th.no_grad(): # Compute value for the last timestep episode_starts = th.tensor(dones).float().to(self.device) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 801c924b..3f93c239 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -3,6 +3,10 @@ from gym.envs.classic_control import CartPoleEnv from gym.wrappers.time_limit import TimeLimit +# from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.evaluation import evaluate_policy + from sb3_contrib import RecurrentPPO @@ -34,16 +38,21 @@ def step(self, action): def test_ppo_lstm(): - from stable_baselines3.common.env_util import make_vec_env - # env = make_vec_env("CartPole-v1", n_envs=8) + env = make_vec_env("CartPole-v1", n_envs=16) def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) return env - env = make_vec_env(make_env, n_envs=8) + env = make_vec_env(make_env, n_envs=16) + + # eval_callback = EvalCallback( + # make_vec_env(make_env, n_envs=4), + # n_eval_episodes=20, + # eval_freq=250 // env.num_envs, + # ) model = RecurrentPPO( "MlpLstmPolicy", @@ -53,14 +62,14 @@ def make_env(): verbose=1, batch_size=256, seed=0, - n_epochs=10, + n_epochs=20, # max_grad_norm=1, gae_lambda=0.98, policy_kwargs=dict(net_arch=[dict(vf=[64])], ortho_init=False), # policy_kwargs=dict(net_arch=[dict(pi=[64], vf=[64])]) - # create_eval_env=True, ) - # model.learn(total_timesteps=500, eval_freq=250) - # model.learn(total_timesteps=1_000_000) - model.learn(total_timesteps=100) + model.learn(total_timesteps=250) + # model.learn(total_timesteps=100_000) + # model.learn(total_timesteps=1000, callback=eval_callback) + evaluate_policy(model, env) From 
362dec4bf834430aebb9222e6e62df41f9a27ade Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 26 Nov 2021 11:40:49 +0100 Subject: [PATCH 14/50] Add CNN support --- sb3_contrib/common/recurrent/policies.py | 160 ++++++++++++++++++++++- sb3_contrib/ppo_lstm/__init__.py | 3 +- sb3_contrib/ppo_lstm/policies.py | 9 +- sb3_contrib/ppo_lstm/ppo_lstm.py | 1 + tests/test_lstm.py | 45 ++++++- 5 files changed, 206 insertions(+), 12 deletions(-) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 91e386dc..dd5cfbf4 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -5,13 +5,17 @@ import torch as th from stable_baselines3.common.distributions import Distribution from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import zip_strict from torch import nn -# CombinedExtractor,; FlattenExtractor,; MlpExtractor,; NatureCNN,; create_mlp, - class RecurrentActorCriticPolicy(ActorCriticPolicy): """ @@ -313,3 +317,153 @@ def predict( actions = actions[0] return actions, states + + +class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): + """ + CNN policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputRecurrentActorCriticPolicy(RecurrentActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) diff --git a/sb3_contrib/ppo_lstm/__init__.py b/sb3_contrib/ppo_lstm/__init__.py index 92333335..9f4fdd39 100644 --- a/sb3_contrib/ppo_lstm/__init__.py +++ b/sb3_contrib/ppo_lstm/__init__.py @@ -1,3 +1,2 @@ -# from sb3_contrib.ppo_lstm.policies import CnnPolicy, MlpPolicy, MultiInputPolicy -from sb3_contrib.ppo_lstm.policies import MlpLstmPolicy +from sb3_contrib.ppo_lstm.policies import CnnLstmPolicy, MlpLstmPolicy from sb3_contrib.ppo_lstm.ppo_lstm import RecurrentPPO diff --git a/sb3_contrib/ppo_lstm/policies.py b/sb3_contrib/ppo_lstm/policies.py index d0d4d7f3..584a60b5 100644 --- a/sb3_contrib/ppo_lstm/policies.py +++ b/sb3_contrib/ppo_lstm/policies.py @@ -1,13 +1,12 @@ from stable_baselines3.common.policies import register_policy -from sb3_contrib.common.recurrent.policies import ( # RecurrentActorCriticCnnPolicy,; RecurrentMultiInputActorCriticPolicy, - RecurrentActorCriticPolicy, -) +from sb3_contrib.common.recurrent.policies import RecurrentActorCriticCnnPolicy # RecurrentMultiInputActorCriticPolicy, +from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy MlpLstmPolicy = RecurrentActorCriticPolicy -# CnnLstmPolicy = RecurrentActorCriticCnnPolicy +CnnLstmPolicy = RecurrentActorCriticCnnPolicy # MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy register_policy("MlpLstmPolicy", RecurrentActorCriticPolicy) -# register_policy("CnnLstmPolicy", RecurrentActorCriticCnnPolicy) +register_policy("CnnLstmPolicy", RecurrentActorCriticCnnPolicy) # register_policy("MultiInputLstmPolicy", RecurrentMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index cdb087d5..1bd9b320 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -139,6 +139,7 @@ def _setup_model(self) -> None: self.observation_space, self.action_space, self.lr_schedule, + use_sde=self.use_sde, **self.policy_kwargs, # pytype:disable=not-instantiable ) self.policy = self.policy.to(self.device) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 3f93c239..b981bf49 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from gym 
import spaces from gym.envs.classic_control import CartPoleEnv from gym.wrappers.time_limit import TimeLimit @@ -37,9 +38,49 @@ def step(self, action): return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info -def test_ppo_lstm(): +def test_cnn(): + model = RecurrentPPO( + "CnnLstmPolicy", + "Breakout-v0", + n_steps=16, + seed=0, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) + + model.learn(total_timesteps=32) + + +@pytest.mark.parametrize("env", ["Pendulum-v0", "CartPole-v1"]) +def test_run(env): + model = RecurrentPPO( + "MlpLstmPolicy", + env, + n_steps=16, + seed=0, + create_eval_env=True, + ) + + model.learn(total_timesteps=32, eval_freq=16) + + +def test_run_sde(): + model = RecurrentPPO( + "MlpLstmPolicy", + "Pendulum-v0", + n_steps=16, + seed=0, + create_eval_env=True, + sde_sample_freq=4, + use_sde=True, + ) + + model.learn(total_timesteps=32, eval_freq=16) + + +def test_ppo_lstm_performance(): - env = make_vec_env("CartPole-v1", n_envs=16) + # env = make_vec_env("CartPole-v1", n_envs=16) + # env = make_vec_env("Pendulum-v0", n_envs=16) def make_env(): env = CartPoleNoVelEnv() From 5b162db8604a052cd20c91b3e3d61fbf4534db5e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 26 Nov 2021 12:24:45 +0100 Subject: [PATCH 15/50] Fix start of sequence --- sb3_contrib/common/recurrent/buffers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 1b5ae891..a4081ebd 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -179,6 +179,8 @@ def _get_samples( ) -> RecurrentRolloutBufferSamples: # Create sequence if env change too seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]) + # First index is always the beginning of a sequence + seq_start[0] = True self.starts = np.where(seq_start == True)[0] # noqa: E712 self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) From 954e6dd7d21313df36124f70bf8080767629bc2a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 26 Nov 2021 14:07:18 +0100 Subject: [PATCH 16/50] Allow shared LSTM --- sb3_contrib/common/recurrent/buffers.py | 43 ++++--------- sb3_contrib/common/recurrent/policies.py | 67 ++++++++++++++------ sb3_contrib/common/recurrent/type_aliases.py | 31 +++++++++ sb3_contrib/ppo_lstm/ppo_lstm.py | 15 +++-- 4 files changed, 99 insertions(+), 57 deletions(-) create mode 100644 sb3_contrib/common/recurrent/type_aliases.py diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index a4081ebd..8644fb1a 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -1,33 +1,16 @@ -from typing import Generator, NamedTuple, Optional, Tuple, Union +from typing import Generator, Optional, Tuple, Union import numpy as np import torch as th from gym import spaces from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer -from stable_baselines3.common.type_aliases import TensorDict from stable_baselines3.common.vec_env import VecNormalize - -class RecurrentRolloutBufferSamples(NamedTuple): - observations: th.Tensor - actions: th.Tensor - old_values: th.Tensor - old_log_prob: th.Tensor - advantages: th.Tensor - returns: th.Tensor - lstm_states: Tuple[th.Tensor, th.Tensor] - episode_starts: th.Tensor - - -class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): - observations: TensorDict - actions: th.Tensor - 
old_values: th.Tensor - old_log_prob: th.Tensor - advantages: th.Tensor - returns: th.Tensor - lstm_states: Tuple[th.Tensor, th.Tensor] - episode_starts: th.Tensor +from sb3_contrib.common.recurrent.type_aliases import ( + RecurrentDictRolloutBufferSamples, + RecurrentRolloutBufferSamples, + RNNStates, +) class RecurrentRolloutBuffer(RolloutBuffer): @@ -147,14 +130,14 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf # self.cell_states[:, :, mini_batch_env_indices, :][0], # ) lstm_states_pi = ( - self.initial_lstm_states[0][0][:, mini_batch_env_indices].clone(), - self.initial_lstm_states[0][1][:, mini_batch_env_indices].clone(), + self.initial_lstm_states.pi[0][:, mini_batch_env_indices].clone(), + self.initial_lstm_states.pi[1][:, mini_batch_env_indices].clone(), ) # lstm_states_vf = ( # self.initial_lstm_states[1][0][:, mini_batch_env_indices].clone(), # self.initial_lstm_states[1][1][:, mini_batch_env_indices].clone(), # ) - lstm_states_vf = None + lstm_states_vf = lstm_states_pi yield RecurrentRolloutBufferSamples( observations=self.to_torch(self.observations[batch_inds]), @@ -163,7 +146,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=(lstm_states_pi, lstm_states_vf), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) @@ -196,7 +179,7 @@ def _get_samples( ) lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) - lstm_states_vf = None + lstm_states_vf = lstm_states_pi return RecurrentRolloutBufferSamples( observations=self.pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), @@ -205,7 +188,7 @@ def _get_samples( old_log_prob=self.pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), advantages=self.pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), returns=self.pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), - lstm_states=(lstm_states_pi, lstm_states_vf), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), episode_starts=self.pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), ) @@ -316,6 +299,6 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=(self.to_torch(self.hidden_states[batch_inds]), self.to_torch(self.cell_states[batch_inds])), + lstm_states=RNNStates(self.to_torch(self.hidden_states[batch_inds]), self.to_torch(self.cell_states[batch_inds])), episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index dd5cfbf4..255b72fa 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -16,6 +16,8 @@ from stable_baselines3.common.utils import zip_strict from torch import nn +from sb3_contrib.common.recurrent.type_aliases import RNNStates + class RecurrentActorCriticPolicy(ActorCriticPolicy): """ @@ -70,9 +72,11 @@ def __init__( normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, 
optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 64, + n_lstm_layers: int = 1, + shared_lstm: bool = False, ): - hidden_size = 64 - self.lstm_output_dim = hidden_size + self.lstm_output_dim = lstm_hidden_size super().__init__( observation_space, action_space, @@ -93,11 +97,13 @@ def __init__( optimizer_kwargs, ) - num_layers = 1 - self.lstm_actor = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) - self.lstm_shape = (num_layers, 1, hidden_size) - # self.lstm_critic = nn.LSTM(self.features_dim, hidden_size, num_layers=num_layers) - self.critic = nn.Linear(self.features_dim, hidden_size) + self.shared_lstm = shared_lstm + self.lstm_actor = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) + self.lstm_shape = (n_lstm_layers, 1, lstm_hidden_size) + # self.lstm_critic = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) + self.critic = None + if not self.shared_lstm: + self.critic = nn.Linear(self.features_dim, lstm_hidden_size) # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -145,10 +151,10 @@ def _process_sequence( def forward( self, obs: th.Tensor, - lstm_states: Tuple[th.Tensor, th.Tensor], + lstm_states: RNNStates, episode_starts: th.Tensor, deterministic: bool = False, - ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Tuple[th.Tensor, ...]]: + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: """ Forward pass in all the networks (actor and critic) @@ -159,11 +165,15 @@ def forward( # Preprocess the observation if needed features = self.extract_features(obs) # latent_pi, latent_vf = self.mlp_extractor(features) - latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states[0], episode_starts, self.lstm_actor) - # TODO: try re-using LSTM features for value function but using detach - # latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states[1], episode_starts, self.lstm_critic) - lstm_states_vf = None - latent_vf = self.critic(features) + latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + # latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + # Re-use LSTM features but do not backpropagate + if self.shared_lstm: + latent_vf = latent_pi.detach() + lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + else: + latent_vf = self.critic(features) + lstm_states_vf = lstm_states_pi latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -173,7 +183,7 @@ def forward( distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) - return actions, values, log_prob, (lstm_states_pi, lstm_states_vf) + return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) def get_distribution( self, @@ -205,8 +215,14 @@ def predict_values( :return: the estimated values. 
""" features = self.extract_features(obs) + # Use LSTM from the actor + if self.shared_lstm: + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + # latent_vf, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) - latent_vf = self.critic(features) latent_vf = self.mlp_extractor.forward_critic(latent_vf) return self.value_net(latent_vf) @@ -214,7 +230,7 @@ def evaluate_actions( self, obs: th.Tensor, actions: th.Tensor, - lstm_states: Tuple[th.Tensor, th.Tensor], + lstm_states: RNNStates, episode_starts: th.Tensor, ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ @@ -229,9 +245,12 @@ def evaluate_actions( # Preprocess the observation if needed features = self.extract_features(obs) # latent_pi, latent_vf = self.mlp_extractor(features) - latent_pi, _ = self._process_sequence(features, lstm_states[0], episode_starts, self.lstm_actor) - # latent_vf, _ = self._process_sequence(features, lstm_states[1], episode_starts, self.lstm_critic) - latent_vf = self.critic(features) + latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + # latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + if self.shared_lstm: + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) @@ -372,6 +391,8 @@ def __init__( normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 64, + n_lstm_layers: int = 1, ): super().__init__( observation_space, @@ -391,6 +412,8 @@ def __init__( normalize_images, optimizer_class, optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, ) @@ -447,6 +470,8 @@ def __init__( normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 64, + n_lstm_layers: int = 1, ): super().__init__( observation_space, @@ -466,4 +491,6 @@ def __init__( normalize_images, optimizer_class, optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, ) diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py new file mode 100644 index 00000000..0bf019ca --- /dev/null +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -0,0 +1,31 @@ +from typing import NamedTuple, Tuple + +import torch as th +from stable_baselines3.common.type_aliases import TensorDict + + +class RNNStates(NamedTuple): + pi: Tuple[th.Tensor, ...] + vf: Tuple[th.Tensor, ...] 
+ + +class RecurrentRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor + + +class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + lstm_states: RNNStates + episode_starts: th.Tensor diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 1bd9b320..700ca737 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -17,6 +17,7 @@ from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy +from sb3_contrib.common.recurrent.type_aliases import RNNStates class RecurrentPPO(OnPolicyAlgorithm): @@ -154,7 +155,7 @@ def _setup_model(self) -> None: single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) # hidden states for actor and critic - self.lstm_states = ( + self.lstm_states = RNNStates( ( th.zeros(single_hidden_state_shape).to(self.device), th.zeros(single_hidden_state_shape).to(self.device), @@ -305,11 +306,11 @@ def collect_rollouts( ): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): - # terminal_lstm_state = ( - # lstm_states[1][0][:, idx : idx + 1, :], - # lstm_states[1][1][:, idx : idx + 1, :], - # ) - terminal_lstm_state = None + terminal_lstm_state = ( + lstm_states.vf[0][:, idx : idx + 1, :], + lstm_states.vf[1][:, idx : idx + 1, :], + ) + # terminal_lstm_state = None episode_starts = th.tensor([False]).float().to(self.device) terminal_value = self.policy.predict_values(terminal_obs, terminal_lstm_state, episode_starts)[0] rewards[idx] += self.gamma * terminal_value @@ -321,7 +322,7 @@ def collect_rollouts( self._last_episode_starts, values, log_probs, - lstm_states=(self.lstm_states[0][0].cpu().numpy(), self.lstm_states[0][1].cpu().numpy()), + lstm_states=(self.lstm_states.pi[0].cpu().numpy(), self.lstm_states.pi[1].cpu().numpy()), ) self._last_obs = new_obs From 832093d853b62232859d559c83a8f392eb301d1a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Nov 2021 17:59:35 +0100 Subject: [PATCH 17/50] Rename mask to episode_start --- sb3_contrib/common/maskable/policies.py | 8 ++++---- sb3_contrib/common/recurrent/policies.py | 14 +++++++------- sb3_contrib/ppo_mask/ppo_mask.py | 17 ++++++++++------- sb3_contrib/qrdqn/qrdqn.py | 15 +++++++++------ 4 files changed, 30 insertions(+), 24 deletions(-) diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index d7c9522a..4b65ad9b 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -228,7 +228,7 @@ def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: @@ -238,7 +238,7 @@ def predict( :param observation: the input observation :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param episode_start: The last masks (can be 
None, used in recurrent policies) :param deterministic: Whether or not to return deterministic actions. :param action_masks: Action masks to apply to the action distribution :return: the model's action and the next state @@ -247,8 +247,8 @@ def predict( # TODO (GH/1): add support for RNN policies # if state is None: # state = self.initial_state - # if mask is None: - # mask = [False for _ in range(self.n_envs)] + # if episode_start is None: + # episode_start = [False for _ in range(self.n_envs)] # Switch to eval mode (this affects batch norm / dropout) self.set_training_mode(False) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 255b72fa..95f16ce2 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -282,16 +282,16 @@ def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Get the policy action and state from an observation (and optional state). + Get the policy action and state from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation :param state: The last hidden states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) this correspond to beginning of episodes, where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. @@ -309,14 +309,14 @@ def predict( state = np.concatenate([np.zeros(self.lstm_shape) for _ in range(n_envs)], axis=1) state = (state, state) - if mask is None: - mask = np.array([False for _ in range(n_envs)]) + if episode_start is None: + episode_start = np.array([False for _ in range(n_envs)]) with th.no_grad(): # Convert to PyTorch tensors states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) - mask = th.tensor(mask).float().to(self.device) - actions, states = self._predict(observation, lstm_states=states, episode_starts=mask, deterministic=deterministic) + episode_starts = th.tensor(episode_start).float().to(self.device) + actions, states = self._predict(observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic) states = (states[0].cpu().numpy(), states[1].cpu().numpy()) # Convert to numpy diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 4648af3f..2fafce9e 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -366,21 +366,24 @@ def predict( self, observation: np.ndarray, state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ - Get the model's action(s) from an observation. + Get the policy action and state from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). 
:param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. - :param action_masks: Action masks to apply to the action distribution. - :return: the model's action and the next state (used in recurrent policies) + :return: the model's action and the next hidden state + (used in recurrent policies) """ - return self.policy.predict(observation, state, mask, deterministic, action_masks=action_masks) + return self.policy.predict(observation, state, episode_start, deterministic, action_masks=action_masks) def train(self) -> None: """ diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index ad6016e0..53d7c8af 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -207,17 +207,20 @@ def predict( self, observation: np.ndarray, state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ - Overrides the base_class predict function to include epsilon-greedy exploration. + Get the policy action and state from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. :param deterministic: Whether or not to return deterministic actions. 
- :return: the model's action and the next state + :return: the model's action and the next hidden state (used in recurrent policies) """ if not deterministic and np.random.rand() < self.exploration_rate: @@ -230,7 +233,7 @@ def predict( else: action = np.array(self.action_space.sample()) else: - action, state = self.policy.predict(observation, state, mask, deterministic) + action, state = self.policy.predict(observation, state, episode_start, deterministic) return action, state def learn( From 2a9c956fe08bb0d756814876c0551fe5e95ba0b4 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Nov 2021 18:06:31 +0100 Subject: [PATCH 18/50] Fix type hint --- sb3_contrib/common/maskable/policies.py | 6 +++--- sb3_contrib/common/recurrent/policies.py | 6 ++++-- sb3_contrib/ppo_mask/ppo_mask.py | 4 ++-- sb3_contrib/qrdqn/qrdqn.py | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 4b65ad9b..0d60489d 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -227,11 +227,11 @@ def _predict( def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the policy action and state from an observation (and optional state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -274,7 +274,7 @@ def predict( raise ValueError("Error: The environment must be vectorized when using recurrent policies.") actions = actions[0] - return actions, state + return actions, None def evaluate_actions( self, diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 95f16ce2..ae419bef 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -281,7 +281,7 @@ def _predict( def predict( self, observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -316,7 +316,9 @@ def predict( # Convert to PyTorch tensors states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) episode_starts = th.tensor(episode_start).float().to(self.device) - actions, states = self._predict(observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic) + actions, states = self._predict( + observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic + ) states = (states[0].cpu().numpy(), states[1].cpu().numpy()) # Convert to numpy diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 2fafce9e..8c8a4c4e 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -365,11 +365,11 @@ def collect_rollouts( def predict( self, observation: np.ndarray, - state: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, action_masks: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, 
Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the policy action and state from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 53d7c8af..b2548ec6 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -206,10 +206,10 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: def predict( self, observation: np.ndarray, - state: Optional[np.ndarray] = None, + state: Optional[Tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ Get the policy action and state from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). From 15c080ac33e1df6337fbddc7ea8ebf74225ffb46 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Nov 2021 20:17:07 +0100 Subject: [PATCH 19/50] Enable LSTM for critic --- sb3_contrib/common/recurrent/buffers.py | 94 +++++++++++++----------- sb3_contrib/common/recurrent/policies.py | 33 ++++++--- sb3_contrib/ppo_lstm/ppo_lstm.py | 17 +++-- tests/test_lstm.py | 13 ++++ 4 files changed, 98 insertions(+), 59 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 8644fb1a..d31f429a 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -37,26 +37,30 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, - sampling_style: str = "default", # "defaults" or "per_env" + sampling_strategy: str = "default", # "default" or "per_env" ): self.lstm_states = lstm_states # self.dones = None self.initial_lstm_states = None - self.sampling_style = sampling_style + self.sampling_strategy = sampling_strategy self.starts, self.ends = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): super().reset() - self.hidden_states = np.zeros_like(self.lstm_states[0]) - self.cell_states = np.zeros_like(self.lstm_states[1]) + self.hidden_states_pi = np.zeros_like(self.lstm_states[0]) + self.cell_states_pi = np.zeros_like(self.lstm_states[1]) + self.hidden_states_vf = np.zeros_like(self.lstm_states[0]) + self.cell_states_vf = np.zeros_like(self.lstm_states[1]) - def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], **kwargs) -> None: + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states[self.pos] = np.array(lstm_states[0]) - self.cell_states[self.pos] = np.array(lstm_states[1]) + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) super().add(*args, **kwargs) @@ -67,8 +71,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if not self.generator_ready: # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - self.hidden_states = 
self.hidden_states.swapaxes(1, 2) - self.cell_states = self.cell_states.swapaxes(1, 2) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) for tensor in [ "observations", @@ -77,8 +81,10 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf "log_probs", "advantages", "returns", - "hidden_states", - "cell_states", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", "episode_starts", ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) @@ -88,7 +94,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - if self.sampling_style == "default": + if self.sampling_strategy == "default": # No shuffling # indices = np.arange(self.buffer_size * self.n_envs) # Trick to shuffle a bit: keep the sequence order @@ -126,18 +132,17 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] batch_inds = flat_indices[mini_batch_env_indices].ravel() # lstm_states_pi = ( - # self.hidden_states[:, :, mini_batch_env_indices, :][0], - # self.cell_states[:, :, mini_batch_env_indices, :][0], + # self.hidden_states_pi[:, :, mini_batch_env_indices, :][0], + # self.cell_states_pi[:, :, mini_batch_env_indices, :][0], # ) lstm_states_pi = ( self.initial_lstm_states.pi[0][:, mini_batch_env_indices].clone(), self.initial_lstm_states.pi[1][:, mini_batch_env_indices].clone(), ) - # lstm_states_vf = ( - # self.initial_lstm_states[1][0][:, mini_batch_env_indices].clone(), - # self.initial_lstm_states[1][1][:, mini_batch_env_indices].clone(), - # ) - lstm_states_vf = lstm_states_pi + lstm_states_vf = ( + self.initial_lstm_states.vf[0][:, mini_batch_env_indices].clone(), + self.initial_lstm_states.vf[1][:, mini_batch_env_indices].clone(), + ) yield RecurrentRolloutBufferSamples( observations=self.to_torch(self.observations[batch_inds]), @@ -167,19 +172,23 @@ def _get_samples( self.starts = np.where(seq_start == True)[0] # noqa: E712 self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) - n_layers = self.hidden_states.shape[1] + n_layers = self.hidden_states_pi.shape[1] n_seq = len(self.starts) max_length = self.pad(self.actions[batch_inds]).shape[0] # TODO: output mask to not backpropagate everywhere padded_batch_size = n_seq * max_length lstm_states_pi = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 - self.cell_states[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.hidden_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.cell_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + ) + lstm_states_vf = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.cell_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 ) lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) - - lstm_states_vf = lstm_states_pi + lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) return RecurrentRolloutBufferSamples( 
observations=self.pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), @@ -228,16 +237,17 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, + sampling_strategy: str = "default", # "default" or "per_env" ): super(RecurrentDictRolloutBuffer, self).__init__( buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs ) self.lstm_states = lstm_states - self.dones = None + self.sampling_strategy = sampling_strategy def reset(self): - self.hidden_states = np.zeros_like(self.lstm_states[0]) - self.cell_states = np.zeros_like(self.lstm_states[1]) + self.hidden_states_pi = np.zeros_like(self.lstm_states[0]) + self.cell_states_pi = np.zeros_like(self.lstm_states[1]) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) super().reset() @@ -245,8 +255,8 @@ def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], dones: np.ndarr """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states[self.pos] = np.array(lstm_states[0]) - self.cell_states[self.pos] = np.array(lstm_states[1]) + self.hidden_states_pi[self.pos] = np.array(lstm_states[0]) + self.cell_states_pi[self.pos] = np.array(lstm_states[1]) self.dones[self.pos] = np.array(dones) super().add(*args, **kwargs) @@ -260,24 +270,22 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou if not self.generator_ready: # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - self.hidden_states = self.hidden_states.swapaxes(1, 2) - self.cell_states = self.cell_states.swapaxes(1, 2) - - for key, obs in self.observations.items(): - self.observations[key] = self.swap_and_flatten(obs) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) - _tensor_names = [ + for tensor in [ + "observations", "actions", "values", "log_probs", "advantages", "returns", - "hidden_states", - "cell_states", - "dones", - ] - - for tensor in _tensor_names: + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True @@ -299,6 +307,8 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), advantages=self.to_torch(self.advantages[batch_inds].flatten()), returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=RNNStates(self.to_torch(self.hidden_states[batch_inds]), self.to_torch(self.cell_states[batch_inds])), + lstm_states=RNNStates( + self.to_torch(self.hidden_states_pi[batch_inds]), self.to_torch(self.cell_states_pi[batch_inds]) + ), episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index ae419bef..6d0bbf31 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -75,6 +75,7 @@ def __init__( lstm_hidden_size: int = 64, n_lstm_layers: int = 1, shared_lstm: bool = False, + enable_critic_lstm: bool = False, ): self.lstm_output_dim = lstm_hidden_size super().__init__( @@ -98,12 +99,20 @@ def __init__( ) self.shared_lstm = shared_lstm + self.enable_critic_lstm = enable_critic_lstm 
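
The sequence padding used by the recurrent buffers above can be hard to follow from the diff alone, so here is a self-contained sketch (toy 1D data, not the real buffer arrays) of how ``starts``/``ends`` are derived from the sequence-start flags and how ``th.nn.utils.rnn.pad_sequence`` turns the resulting variable-length sequences into one padded batch, mirroring the buffer's ``pad()`` helper:

    import numpy as np
    import torch as th

    # Toy flattened rollout of 9 steps; True marks the first step of a new sequence
    seq_start = np.array([1, 0, 0, 1, 0, 1, 0, 0, 0], dtype=bool)
    starts = np.where(seq_start)[0]                                         # [0, 3, 5]
    ends = np.concatenate([(starts - 1)[1:], np.array([len(seq_start)])])   # [2, 4, 9]

    data = th.arange(len(seq_start), dtype=th.float32)
    # Variable-length sequences of lengths 3, 2 and 4
    seqs = [data[start : end + 1] for start, end in zip(starts, ends)]
    # Zero-padded tensor of shape (max_length, n_seq) = (4, 3)
    padded = th.nn.utils.rnn.pad_sequence(seqs)
    padded_batch_size = padded.shape[0] * padded.shape[1]  # n_seq * max_length = 12
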
self.lstm_actor = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) self.lstm_shape = (n_lstm_layers, 1, lstm_hidden_size) - # self.lstm_critic = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) self.critic = None - if not self.shared_lstm: + self.lstm_critic = None + assert not ( + self.shared_lstm and self.enable_critic_lstm + ), "You must choose between shared LSTM, seperate or no LSTM for the critic" + if not (self.shared_lstm or self.enable_critic_lstm): self.critic = nn.Linear(self.features_dim, lstm_hidden_size) + + if self.enable_critic_lstm: + self.lstm_critic = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) + # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -166,9 +175,10 @@ def forward( features = self.extract_features(obs) # latent_pi, latent_vf = self.mlp_extractor(features) latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) - # latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) - # Re-use LSTM features but do not backpropagate - if self.shared_lstm: + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Re-use LSTM features but do not backpropagate latent_vf = latent_pi.detach() lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) else: @@ -215,8 +225,10 @@ def predict_values( :return: the estimated values. """ features = self.extract_features(obs) - # Use LSTM from the actor - if self.shared_lstm: + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Use LSTM from the actor latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) latent_vf = latent_pi.detach() else: @@ -244,10 +256,11 @@ def evaluate_actions( """ # Preprocess the observation if needed features = self.extract_features(obs) - # latent_pi, latent_vf = self.mlp_extractor(features) latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) - # latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) - if self.shared_lstm: + + if self.lstm_critic is not None: + latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: latent_vf = latent_pi.detach() else: latent_vf = self.critic(features) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index 700ca737..f20e1f9f 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -82,6 +82,7 @@ def __init__( use_sde: bool = False, sde_sample_freq: int = -1, target_kl: Optional[float] = None, + sampling_strategy: str = "default", # "default" or "per_env" tensorboard_log: Optional[str] = None, create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, @@ -123,7 +124,8 @@ def __init__( self.clip_range = clip_range self.clip_range_vf = clip_range_vf self.target_kl = target_kl - self.lstm_states = None + self._last_lstm_states = None + self.sampling_strategy = sampling_strategy if _init_setup_model: self._setup_model() @@ -155,7 +157,7 @@ def _setup_model(self) -> None: 
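
Since the three critic variants are easy to mix up, here is a minimal usage sketch (assuming the public ``RecurrentPPO`` entry point from this patch series) of how they are selected through ``policy_kwargs``; ``shared_lstm`` and ``enable_critic_lstm`` are mutually exclusive, as enforced by the assertion above:

    from sb3_contrib import RecurrentPPO

    # Default: recurrent actor, simple feed-forward critic on the extracted features
    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", policy_kwargs={})

    # Critic re-uses the actor LSTM features (detached, so the value loss
    # does not backpropagate through the shared LSTM)
    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", policy_kwargs=dict(shared_lstm=True))

    # Separate LSTM for the critic
    model = RecurrentPPO(
        "MlpLstmPolicy",
        "CartPole-v1",
        policy_kwargs=dict(enable_critic_lstm=True, lstm_hidden_size=64),
    )
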
single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) # hidden states for actor and critic - self.lstm_states = RNNStates( + self._last_lstm_states = RNNStates( ( th.zeros(single_hidden_state_shape).to(self.device), th.zeros(single_hidden_state_shape).to(self.device), @@ -175,6 +177,7 @@ def _setup_model(self) -> None: gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, + sampling_strategy=self.sampling_strategy, ) # Initialize schedules for policy/value clipping @@ -258,8 +261,8 @@ def collect_rollouts( callback.on_rollout_start() - rollout_buffer.initial_lstm_states = deepcopy(self.lstm_states) - lstm_states = deepcopy(self.lstm_states) + rollout_buffer.initial_lstm_states = deepcopy(self._last_lstm_states) + lstm_states = deepcopy(self._last_lstm_states) while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: @@ -322,17 +325,17 @@ def collect_rollouts( self._last_episode_starts, values, log_probs, - lstm_states=(self.lstm_states.pi[0].cpu().numpy(), self.lstm_states.pi[1].cpu().numpy()), + lstm_states=self._last_lstm_states, ) self._last_obs = new_obs self._last_episode_starts = dones - self.lstm_states = lstm_states + self._last_lstm_states = lstm_states with th.no_grad(): # Compute value for the last timestep episode_starts = th.tensor(dones).float().to(self.device) - values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states[1], episode_starts) + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device), lstm_states.vf, episode_starts) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index b981bf49..7eade194 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -50,6 +50,19 @@ def test_cnn(): model.learn(total_timesteps=32) +@pytest.mark.parametrize("policy_kwargs", [{}, dict(shared_lstm=True), dict(enable_critic_lstm=True, lstm_hidden_size=4)]) +def test_policy_kwargs(policy_kwargs): + model = RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + model.learn(total_timesteps=32) + + @pytest.mark.parametrize("env", ["Pendulum-v0", "CartPole-v1"]) def test_run(env): model = RecurrentPPO( From 0d304aaf32841202f294699df542e56d729e82c3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Nov 2021 20:26:18 +0100 Subject: [PATCH 20/50] Clean code --- sb3_contrib/common/recurrent/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 6d0bbf31..7198354c 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -107,6 +107,7 @@ def __init__( assert not ( self.shared_lstm and self.enable_critic_lstm ), "You must choose between shared LSTM, seperate or no LSTM for the critic" + if not (self.shared_lstm or self.enable_critic_lstm): self.critic = nn.Linear(self.features_dim, lstm_hidden_size) @@ -234,7 +235,6 @@ def predict_values( else: latent_vf = self.critic(features) - # latent_vf, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) latent_vf = self.mlp_extractor.forward_critic(latent_vf) return self.value_net(latent_vf) From 1dc78b4d8bceac2c83564fb5b80c3dc8f5bbefdc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Nov 2021 20:42:59 +0100 Subject: [PATCH 21/50] Fix for CNN LSTM --- 
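
As a quick reference for the state bookkeeping in ``_setup_model`` and ``collect_rollouts`` above, this is the layout PyTorch expects for LSTM hidden states, sketched with a plain ``nn.LSTM`` (dimensions chosen arbitrarily for illustration):

    import torch as th
    from torch import nn

    n_envs, n_lstm_layers, hidden_size, features_dim = 4, 1, 64, 8
    lstm = nn.LSTM(features_dim, hidden_size, num_layers=n_lstm_layers)

    # Matches single_hidden_state_shape = (num_layers, n_envs, hidden_size)
    hidden_state = th.zeros(n_lstm_layers, n_envs, hidden_size)
    cell_state = th.zeros(n_lstm_layers, n_envs, hidden_size)

    # One environment step for all envs at once: sequence length 1, batch size n_envs
    features = th.zeros(1, n_envs, features_dim)
    output, (hidden_state, cell_state) = lstm(features, (hidden_state, cell_state))
    assert output.shape == (1, n_envs, hidden_size)
    assert hidden_state.shape == (n_lstm_layers, n_envs, hidden_size)
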
sb3_contrib/common/recurrent/policies.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 7198354c..45db3e0f 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -408,6 +408,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, lstm_hidden_size: int = 64, n_lstm_layers: int = 1, + enable_critic_lstm: bool = False, ): super().__init__( observation_space, @@ -429,6 +430,7 @@ def __init__( optimizer_kwargs, lstm_hidden_size, n_lstm_layers, + enable_critic_lstm, ) @@ -487,6 +489,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, lstm_hidden_size: int = 64, n_lstm_layers: int = 1, + enable_critic_lstm: bool = False, ): super().__init__( observation_space, @@ -508,4 +511,5 @@ def __init__( optimizer_kwargs, lstm_hidden_size, n_lstm_layers, + enable_critic_lstm, ) From deaa7b4a6e8b7649ee95b4f855c3ad6aa2a8f24b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Nov 2021 20:58:20 +0100 Subject: [PATCH 22/50] Fix sampling with n_layers > 1 --- sb3_contrib/common/recurrent/buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index d31f429a..8e681ddd 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -166,7 +166,7 @@ def _get_samples( env: Optional[VecNormalize] = None, ) -> RecurrentRolloutBufferSamples: # Create sequence if env change too - seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]) + seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]).flatten() # First index is always the beginning of a sequence seq_start[0] = True self.starts = np.where(seq_start == True)[0] # noqa: E712 From ced6aee532e7570a93ec89ab0bf99d25d762743d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 29 Nov 2021 12:31:05 +0100 Subject: [PATCH 23/50] Add std logger --- sb3_contrib/ppo_lstm/ppo_lstm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_lstm/ppo_lstm.py index f20e1f9f..2e8ba40c 100644 --- a/sb3_contrib/ppo_lstm/ppo_lstm.py +++ b/sb3_contrib/ppo_lstm/ppo_lstm.py @@ -463,6 +463,9 @@ def train(self) -> None: self.logger.record("train/clip_fraction", np.mean(clip_fractions)) self.logger.record("train/loss", loss.item()) self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/clip_range", clip_range) if self.clip_range_vf is not None: From b81fdffe65f1935bda29c143f70026c3b86e5d1a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 30 Nov 2021 10:00:24 +0100 Subject: [PATCH 24/50] Update wording --- sb3_contrib/common/maskable/policies.py | 2 +- sb3_contrib/common/recurrent/policies.py | 2 +- sb3_contrib/ppo_mask/ppo_mask.py | 2 +- sb3_contrib/qrdqn/qrdqn.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sb3_contrib/common/maskable/policies.py b/sb3_contrib/common/maskable/policies.py index 0d60489d..c88b606b 100644 --- a/sb3_contrib/common/maskable/policies.py +++ b/sb3_contrib/common/maskable/policies.py @@ -233,7 +233,7 @@ def predict( action_masks: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, 
Optional[Tuple[np.ndarray, ...]]]: """ - Get the policy action and state from an observation (and optional state). + Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 45db3e0f..6bf605f3 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -299,7 +299,7 @@ def predict( deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Get the policy action and state from an observation (and optional hidden state). + Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py index 8c8a4c4e..1edfd884 100644 --- a/sb3_contrib/ppo_mask/ppo_mask.py +++ b/sb3_contrib/ppo_mask/ppo_mask.py @@ -371,7 +371,7 @@ def predict( action_masks: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Get the policy action and state from an observation (and optional hidden state). + Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index b2548ec6..5416ecb0 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -211,7 +211,7 @@ def predict( deterministic: bool = False, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """ - Get the policy action and state from an observation (and optional hidden state). + Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). 
:param observation: the input observation From c9c0b4e96b5d17df6778993ff2c1e5d2ef86c340 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 27 Dec 2021 14:58:47 +0100 Subject: [PATCH 25/50] Rename and add dict obs support --- docs/guide/algos.rst | 6 +- docs/index.rst | 1 + docs/misc/changelog.rst | 1 + docs/modules/ppo_mask.rst | 2 +- docs/modules/ppo_recurrent.rst | 127 ++++++++++++++++++ sb3_contrib/__init__.py | 2 +- sb3_contrib/common/recurrent/buffers.py | 103 ++++++++++---- sb3_contrib/common/recurrent/policies.py | 62 ++++++--- sb3_contrib/ppo_lstm/__init__.py | 2 - sb3_contrib/ppo_lstm/policies.py | 12 -- sb3_contrib/ppo_recurrent/__init__.py | 2 + sb3_contrib/ppo_recurrent/policies.py | 15 +++ .../ppo_recurrent.py} | 0 setup.cfg | 2 +- 14 files changed, 275 insertions(+), 62 deletions(-) create mode 100644 docs/modules/ppo_recurrent.rst delete mode 100644 sb3_contrib/ppo_lstm/__init__.py delete mode 100644 sb3_contrib/ppo_lstm/policies.py create mode 100644 sb3_contrib/ppo_recurrent/__init__.py create mode 100644 sb3_contrib/ppo_recurrent/policies.py rename sb3_contrib/{ppo_lstm/ppo_lstm.py => ppo_recurrent/ppo_recurrent.py} (100%) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 879a84e1..61ca9839 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -8,8 +8,10 @@ along with some useful characteristics: support for discrete/continuous actions, ============ =========== ============ ================= =============== ================ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ -TQC ✔️ ❌ ❌ ❌ ✔️ -QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ +MaskablePPO ❌ ✔️ ✔️ ✔️ ✔️ +RecurrentPPO ✔️ ✔️ ✔️ ✔️ ✔️ +QR-DQN ❌ ✔️ ❌ ❌ ✔️ +TQC ✔️ ❌ ❌ ❌ ✔️ ============ =========== ============ ================= =============== ================ diff --git a/docs/index.rst b/docs/index.rst index 8e37c71b..3b4540cc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d modules/tqc modules/qrdqn modules/ppo_mask + modules/ppo_recurrent .. toctree:: :maxdepth: 1 diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0723975f..d80ca21e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported) +- Added ``RecurrentPPO`` Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/ppo_mask.rst b/docs/modules/ppo_mask.rst index 9580ff38..a43f5969 100644 --- a/docs/modules/ppo_mask.rst +++ b/docs/modules/ppo_mask.rst @@ -5,7 +5,7 @@ Maskable PPO ============ -Implementation of `invalid action masking `_ for the Proximal Policy Optimization(PPO) +Implementation of `invalid action masking `_ for the Proximal Policy Optimization (PPO) algorithm. Other than adding support for action masking, the behavior is the same as in SB3's core PPO algorithm. diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst new file mode 100644 index 00000000..e88f55ed --- /dev/null +++ b/docs/modules/ppo_recurrent.rst @@ -0,0 +1,127 @@ +.. _ppo_mask: + +.. automodule:: sb3_contrib.ppo_recurrent + +Recurrent PPO +============= + +Implementation of recurrent policies for the Proximal Policy Optimization (PPO) +algorithm. 
Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm. + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpLstmPolicy + CnnLstmPolicy + MultiInputLstmPolicy + + +Notes +----- + +.. - Paper: https://arxiv.org/abs/2006.14171 +.. - Blog post: https://costa.sh/blog-a-closer-look-at-invalid-action-masking-in-policy-gradient-algorithms.html + + +Can I use? +---------- + +- Recurrent policies: ✔️ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ✔️ ✔️ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ +Dict ❌ ✔️ +============= ====== =========== + + +Example +------- + + +.. code-block:: python + + import numpy as np + + from sb3_contrib import RecurrentPPO + from stable_baselines3.common.evaluation import evaluate_policy + + model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) + model.learn(5000) + + mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=20, warn=False) + print(mean_reward) + + model.save("ppo_recurrent") + del model # remove to demonstrate saving and loading + + model = RecurrentPPO.load("ppo_recurrent") + + env = model.get_env() + obs = env.reset() + states = None + num_envs = 1 + episode_starts = np.ones((num_envs,), dtype=bool) + while True: + action, states = model.predict(obs, state=states, episode_start=episode_starts, deterministic=True) + obs, rewards, dones, info = env.step(action) + episode_starts = dones + env.render() + + + +Results +------- + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the repo for the experiment: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + git checkout feat/recurrent-ppo + +Parameters +---------- + +.. autoclass:: RecurrentPPO + :members: + :inherited-members: + + +RecurrentPPO Policies +--------------------- + +.. autoclass:: MlpLstmPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy + :members: + :noindex: + +.. autoclass:: CnnLstmPolicy + :members: + +.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy + :members: + :noindex: + +.. autoclass:: MultiInputLstmPolicy + :members: + +.. 
autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy + :members: + :noindex: diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 4420815d..8d5bec64 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,7 +1,7 @@ import os -from sb3_contrib.ppo_lstm import RecurrentPPO from sb3_contrib.ppo_mask import MaskablePPO +from sb3_contrib.ppo_recurrent import RecurrentPPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 8e681ddd..c789f3e6 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -94,6 +94,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs + # Sampling strategy that allows any mini batch size but requires + # more complexity and use of padding if self.sampling_strategy == "default": # No shuffling # indices = np.arange(self.buffer_size * self.n_envs) @@ -104,6 +106,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf indices = np.concatenate((indices[split_index:], indices[:split_index])) env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment env_change[0, :] = 1.0 env_change = self.swap_and_flatten(env_change) @@ -114,8 +117,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf start_idx += batch_size return - # Baselines way of sampling, constraint in the batch size - # and number of environments + # ==== OpenAI Baselines way of sampling, constraint in the batch size and number of environments ==== n_minibatches = (self.buffer_size * self.n_envs) // batch_size assert ( @@ -131,10 +133,6 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf end_env_idx = start_env_idx + n_envs_per_batch mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] batch_inds = flat_indices[mini_batch_env_indices].ravel() - # lstm_states_pi = ( - # self.hidden_states_pi[:, :, mini_batch_env_indices, :][0], - # self.cell_states_pi[:, :, mini_batch_env_indices, :][0], - # ) lstm_states_pi = ( self.initial_lstm_states.pi[0][:, mini_batch_env_indices].clone(), self.initial_lstm_states.pi[1][:, mini_batch_env_indices].clone(), @@ -244,28 +242,29 @@ def __init__( ) self.lstm_states = lstm_states self.sampling_strategy = sampling_strategy + assert sampling_strategy == "default", "'per_env' strategy not supported with dict obs" def reset(self): + super().reset() self.hidden_states_pi = np.zeros_like(self.lstm_states[0]) self.cell_states_pi = np.zeros_like(self.lstm_states[1]) - self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) - super().reset() + self.hidden_states_vf = np.zeros_like(self.lstm_states[0]) + self.cell_states_vf = np.zeros_like(self.lstm_states[1]) - def add(self, *args, lstm_states: Tuple[np.ndarray, np.ndarray], dones: np.ndarray, **kwargs) -> None: + def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - self.hidden_states_pi[self.pos] = np.array(lstm_states[0]) - self.cell_states_pi[self.pos] = np.array(lstm_states[1]) - self.dones[self.pos] = np.array(dones) + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + self.cell_states_pi[self.pos] = 
np.array(lstm_states.pi[1].cpu().numpy()) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) super().add(*args, **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: assert self.full, "" - # indices = np.random.permutation(self.buffer_size * self.n_envs) - # Do not shuffle the data - indices = np.arange(self.buffer_size * self.n_envs) + # Prepare the data if not self.generator_ready: # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) @@ -293,22 +292,72 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou if batch_size is None: batch_size = self.buffer_size * self.n_envs + # No shuffling: + # indices = np.arange(self.buffer_size * self.n_envs) + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + start_idx = 0 while start_idx < self.buffer_size * self.n_envs: - yield self._get_samples(indices[start_idx : start_idx + batch_size]) + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) start_idx += batch_size - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RecurrentDictRolloutBufferSamples: + def pad(self, tensor: np.ndarray) -> th.Tensor: + seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] + return th.nn.utils.rnn.pad_sequence(seq) + + def _get_samples( + self, + batch_inds: np.ndarray, + env_change: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentDictRolloutBufferSamples: + # Create sequence if env change too + seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]).flatten() + # First index is always the beginning of a sequence + seq_start[0] = True + self.starts = np.where(seq_start == True)[0] # noqa: E712 + self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) + + n_layers = self.hidden_states_pi.shape[1] + n_seq = len(self.starts) + max_length = self.pad(self.actions[batch_inds]).shape[0] + # TODO: output mask to not backpropagate everywhere + padded_batch_size = n_seq * max_length + lstm_states_pi = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.cell_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + ) + lstm_states_vf = ( + # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.cell_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + ) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) + + observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} + 
observations = { + key: obs.swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape) for (key, obs) in observations.items() + } return RecurrentDictRolloutBufferSamples( - observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, - actions=self.to_torch(self.actions[batch_inds]), - old_values=self.to_torch(self.values[batch_inds].flatten()), - old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), - advantages=self.to_torch(self.advantages[batch_inds].flatten()), - returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=RNNStates( - self.to_torch(self.hidden_states_pi[batch_inds]), self.to_torch(self.cell_states_pi[batch_inds]) - ), - episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), + observations=observations, + actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), + old_values=self.pad(self.values[batch_inds]).swapaxes(0, 1).flatten(), + old_log_prob=self.pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), + advantages=self.pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), + returns=self.pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), + lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), + episode_starts=self.pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), ) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 6bf605f3..8450a6aa 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -21,8 +21,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): """ - CNN policy class for actor-critic algorithms (has both policy and value prediction). - Used by A2C, PPO and the likes. + Recurrent policy class for actor-critic algorithms (has both policy and value prediction). + To be used with A2C, PPO and the likes. :param observation_space: Observation space :param action_space: Action space @@ -51,6 +51,11 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. """ def __init__( @@ -72,7 +77,7 @@ def __init__( normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, - lstm_hidden_size: int = 64, + lstm_hidden_size: int = 256, n_lstm_layers: int = 1, shared_lstm: bool = False, enable_critic_lstm: bool = False, @@ -168,7 +173,10 @@ def forward( """ Forward pass in all the networks (actor and critic) - :param obs: Observation + :param obs: Observation. Observation + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). :param deterministic: Whether to sample or use deterministic actions :return: action, value and log probability of the action """ @@ -205,7 +213,10 @@ def get_distribution( """ Get the current policy distribution given the observations. - :param obs: + :param obs: Observation. 
+ :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). :return: the action distribution and new hidden states. """ features = self.extract_features(obs) @@ -222,7 +233,10 @@ def predict_values( """ Get the estimated values according to the current policy given the observations. - :param obs: + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). :return: the estimated values. """ features = self.extract_features(obs) @@ -249,8 +263,11 @@ def evaluate_actions( Evaluate actions according to the current policy, given the observations. - :param obs: + :param obs: Observation. :param actions: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ @@ -284,7 +301,9 @@ def _predict( Get the action according to the policy for a given observation. :param observation: - :param lstm_states: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy and hidden states of the RNN """ @@ -303,10 +322,9 @@ def predict( Includes sugar-coating to handle different observations (e.g. normalizing images). :param observation: the input observation - :param state: The last hidden states (can be None, used in recurrent policies) - :param episode_start: The last masks (can be None, used in recurrent policies) - this correspond to beginning of episodes, - where the hidden states of the RNN must be reset. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). :param deterministic: Whether or not to return deterministic actions. :return: the model's action and the next hidden state (used in recurrent policies) @@ -319,6 +337,7 @@ def predict( n_envs = observation.shape[0] # state : (n_layers, n_envs, dim) if state is None: + # Initialize hidden states to zeros state = np.concatenate([np.zeros(self.lstm_shape) for _ in range(n_envs)], axis=1) state = (state, state) @@ -355,7 +374,7 @@ def predict( class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): """ - CNN policy class for actor-critic algorithms (has both policy and value prediction). + CNN recurrent policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes. :param observation_space: Observation space @@ -385,6 +404,12 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. 
+ By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + """ def __init__( @@ -406,7 +431,7 @@ def __init__( normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, - lstm_hidden_size: int = 64, + lstm_hidden_size: int = 256, n_lstm_layers: int = 1, enable_critic_lstm: bool = False, ): @@ -434,7 +459,7 @@ def __init__( ) -class MultiInputRecurrentActorCriticPolicy(RecurrentActorCriticPolicy): +class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): """ MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes. @@ -466,6 +491,11 @@ class MultiInputRecurrentActorCriticPolicy(RecurrentActorCriticPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. """ def __init__( @@ -487,7 +517,7 @@ def __init__( normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, - lstm_hidden_size: int = 64, + lstm_hidden_size: int = 256, n_lstm_layers: int = 1, enable_critic_lstm: bool = False, ): diff --git a/sb3_contrib/ppo_lstm/__init__.py b/sb3_contrib/ppo_lstm/__init__.py deleted file mode 100644 index 9f4fdd39..00000000 --- a/sb3_contrib/ppo_lstm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from sb3_contrib.ppo_lstm.policies import CnnLstmPolicy, MlpLstmPolicy -from sb3_contrib.ppo_lstm.ppo_lstm import RecurrentPPO diff --git a/sb3_contrib/ppo_lstm/policies.py b/sb3_contrib/ppo_lstm/policies.py deleted file mode 100644 index 584a60b5..00000000 --- a/sb3_contrib/ppo_lstm/policies.py +++ /dev/null @@ -1,12 +0,0 @@ -from stable_baselines3.common.policies import register_policy - -from sb3_contrib.common.recurrent.policies import RecurrentActorCriticCnnPolicy # RecurrentMultiInputActorCriticPolicy, -from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy - -MlpLstmPolicy = RecurrentActorCriticPolicy -CnnLstmPolicy = RecurrentActorCriticCnnPolicy -# MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy - -register_policy("MlpLstmPolicy", RecurrentActorCriticPolicy) -register_policy("CnnLstmPolicy", RecurrentActorCriticCnnPolicy) -# register_policy("MultiInputLstmPolicy", RecurrentMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_recurrent/__init__.py b/sb3_contrib/ppo_recurrent/__init__.py new file mode 100644 index 00000000..3fb5436e --- /dev/null +++ b/sb3_contrib/ppo_recurrent/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy +from sb3_contrib.ppo_recurrent.ppo_recurrent import RecurrentPPO diff --git a/sb3_contrib/ppo_recurrent/policies.py b/sb3_contrib/ppo_recurrent/policies.py new file mode 100644 index 00000000..ce1fe7d3 --- /dev/null +++ b/sb3_contrib/ppo_recurrent/policies.py @@ -0,0 +1,15 @@ +from stable_baselines3.common.policies import register_policy + +from sb3_contrib.common.recurrent.policies import ( + RecurrentActorCriticCnnPolicy, + RecurrentActorCriticPolicy, + 
RecurrentMultiInputActorCriticPolicy, +) + +MlpLstmPolicy = RecurrentActorCriticPolicy +CnnLstmPolicy = RecurrentActorCriticCnnPolicy +MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy + +register_policy("MlpLstmPolicy", RecurrentActorCriticPolicy) +register_policy("CnnLstmPolicy", RecurrentActorCriticCnnPolicy) +register_policy("MultiInputLstmPolicy", RecurrentMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_lstm/ppo_lstm.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py similarity index 100% rename from sb3_contrib/ppo_lstm/ppo_lstm.py rename to sb3_contrib/ppo_recurrent/ppo_recurrent.py diff --git a/setup.cfg b/setup.cfg index 2e8b5010..adf3ef94 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,7 +23,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators per-file-ignores = ./sb3_contrib/__init__.py:F401 ./sb3_contrib/ppo_mask/__init__.py:F401 - ./sb3_contrib/ppo_lstm/__init__.py:F401 + ./sb3_contrib/ppo_recurrent/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401 From a4b769f9285305742bcee13955d337842f0623c9 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 27 Dec 2021 15:12:16 +0100 Subject: [PATCH 26/50] Fixes for dict obs support --- sb3_contrib/common/recurrent/buffers.py | 12 ++++---- sb3_contrib/common/recurrent/policies.py | 5 +++- tests/test_lstm.py | 36 +++++++++++++++++++----- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index c789f3e6..57063c64 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -40,7 +40,6 @@ def __init__( sampling_strategy: str = "default", # "default" or "per_env" ): self.lstm_states = lstm_states - # self.dones = None self.initial_lstm_states = None self.sampling_strategy = sampling_strategy self.starts, self.ends = None, None @@ -237,12 +236,11 @@ def __init__( n_envs: int = 1, sampling_strategy: str = "default", # "default" or "per_env" ): - super(RecurrentDictRolloutBuffer, self).__init__( - buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs - ) self.lstm_states = lstm_states + self.initial_lstm_states = None self.sampling_strategy = sampling_strategy assert sampling_strategy == "default", "'per_env' strategy not supported with dict obs" + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self): super().reset() @@ -272,8 +270,10 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + for tensor in [ - "observations", "actions", "values", "log_probs", @@ -348,7 +348,7 @@ def _get_samples( observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} observations = { - key: obs.swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape) for (key, obs) in observations.items() + key: obs.swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items() } return RecurrentDictRolloutBufferSamples( diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 8450a6aa..2b97fd69 100644 
--- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -334,7 +334,10 @@ def predict( observation, vectorized_env = self.obs_to_tensor(observation) - n_envs = observation.shape[0] + if isinstance(observation, dict): + n_envs = observation[list(observation.keys())[0]].shape[0] + else: + n_envs = observation.shape[0] # state : (n_layers, n_envs, dim) if state is None: # Initialize hidden states to zeros diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 7eade194..12e6961f 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -1,3 +1,4 @@ +import gym import numpy as np import pytest from gym import spaces @@ -11,6 +12,23 @@ from sb3_contrib import RecurrentPPO +class ToDictWrapper(gym.Wrapper): + """ + Simple wrapper to test MultInputPolicy on Dict obs. + """ + + def __init__(self, env): + super().__init__(env) + self.observation_space = gym.spaces.Dict({"obs": self.env.observation_space}) + + def reset(self): + return {"obs": self.env.reset()} + + def step(self, action): + obs, reward, done, infos = self.env.step(action) + return {"obs": obs}, reward, done, infos + + class CartPoleNoVelEnv(CartPoleEnv): """Variant of CartPoleEnv with velocity information removed. This task requires memory to solve.""" @@ -90,17 +108,21 @@ def test_run_sde(): model.learn(total_timesteps=32, eval_freq=16) -def test_ppo_lstm_performance(): +def test_dict_obs(): + env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper) + model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64) + evaluate_policy(model, env, warn=False) + +def test_ppo_lstm_performance(): # env = make_vec_env("CartPole-v1", n_envs=16) - # env = make_vec_env("Pendulum-v0", n_envs=16) def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) return env - env = make_vec_env(make_env, n_envs=16) + env = make_vec_env(make_env, n_envs=8) # eval_callback = EvalCallback( # make_vec_env(make_env, n_envs=4), @@ -111,7 +133,7 @@ def make_env(): model = RecurrentPPO( "MlpLstmPolicy", env, - n_steps=32, + n_steps=128, learning_rate=0.0007, verbose=1, batch_size=256, @@ -123,7 +145,7 @@ def make_env(): # policy_kwargs=dict(net_arch=[dict(pi=[64], vf=[64])]) ) - model.learn(total_timesteps=250) - # model.learn(total_timesteps=100_000) + # model.learn(total_timesteps=250) + model.learn(total_timesteps=100_000) # model.learn(total_timesteps=1000, callback=eval_callback) - evaluate_policy(model, env) + evaluate_policy(model, env, reward_threshold=100) From 5cadc1475014c1f3b46535f16e100f2f2c40a26c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 27 Dec 2021 15:30:57 +0100 Subject: [PATCH 27/50] Do not run slow tests --- scripts/run_tests.sh | 2 +- setup.cfg | 2 ++ tests/test_lstm.py | 30 +++++++++++++++--------------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index a3795075..45931f0b 100755 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -1,2 +1,2 @@ #!/bin/bash -python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. -v --color=yes +python3 -m pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=. 
-v --color=yes -m "not slow" diff --git a/setup.cfg b/setup.cfg index adf3ef94..079e9ac5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,8 @@ filterwarnings = ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym +markers = + slow: marks tests as slow (deselect with '-m "not slow"') [pytype] inputs = sb3_contrib diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 12e6961f..b30d2318 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -4,8 +4,7 @@ from gym import spaces from gym.envs.classic_control import CartPoleEnv from gym.wrappers.time_limit import TimeLimit - -# from stable_baselines3.common.callbacks import EvalCallback +from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import evaluate_policy @@ -114,9 +113,9 @@ def test_dict_obs(): evaluate_policy(model, env, warn=False) +@pytest.mark.slow def test_ppo_lstm_performance(): # env = make_vec_env("CartPole-v1", n_envs=16) - def make_env(): env = CartPoleNoVelEnv() env = TimeLimit(env, max_episode_steps=500) @@ -124,11 +123,11 @@ def make_env(): env = make_vec_env(make_env, n_envs=8) - # eval_callback = EvalCallback( - # make_vec_env(make_env, n_envs=4), - # n_eval_episodes=20, - # eval_freq=250 // env.num_envs, - # ) + eval_callback = EvalCallback( + make_vec_env(make_env, n_envs=4), + n_eval_episodes=20, + eval_freq=5000 // env.num_envs, + ) model = RecurrentPPO( "MlpLstmPolicy", @@ -138,14 +137,15 @@ def make_env(): verbose=1, batch_size=256, seed=0, - n_epochs=20, - # max_grad_norm=1, + n_epochs=10, + max_grad_norm=1, gae_lambda=0.98, policy_kwargs=dict(net_arch=[dict(vf=[64])], ortho_init=False), - # policy_kwargs=dict(net_arch=[dict(pi=[64], vf=[64])]) ) - # model.learn(total_timesteps=250) - model.learn(total_timesteps=100_000) - # model.learn(total_timesteps=1000, callback=eval_callback) - evaluate_policy(model, env, reward_threshold=100) + model.learn(total_timesteps=50_000, callback=eval_callback) + # Maximum episode reward is 500. + # In CartPole-v1, a non-recurrent policy can easily get >= 450. + # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50. + # LSTM policies can reach above 400, but it varies a lot between runs; consistently get >=150. + evaluate_policy(model, env, reward_threshold=160) From c1f88128c757927f25ab2e18af9d12ff554e425c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Dec 2021 12:16:15 +0100 Subject: [PATCH 28/50] Fix doc --- docs/guide/examples.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 3af55613..c39ce260 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -45,15 +45,15 @@ Train a PPO with invalid action masking agent on a toy environment. model.learn(5000) model.save("qrdqn_cartpole") - TRPO - ---- +TRPO +---- - Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment. +Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment. - .. code-block:: python +.. 
code-block:: python - from sb3_contrib import TRPO + from sb3_contrib import TRPO - model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) - model.learn(total_timesteps=100_000, log_interval=4) - model.save("trpo_pendulum") + model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) + model.learn(total_timesteps=100_000, log_interval=4) + model.save("trpo_pendulum") From 579e7d0bbc2123188175cb3f7242dde7236483e2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Dec 2021 12:19:51 +0100 Subject: [PATCH 29/50] Update recurrent PPO example --- docs/guide/examples.rst | 25 +++++++++++++++++++++++++ docs/modules/ppo_recurrent.rst | 8 ++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index c39ce260..32554b22 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -57,3 +57,28 @@ Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environmen model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) model.learn(total_timesteps=100_000, log_interval=4) model.save("trpo_pendulum") + +RecurrentPPO +------------ + +Train a PPO agent with a recurrent policy on the CartPole environment. + +.. code-block:: python + + import numpy as np + + from sb3_contrib import RecurrentPPO + + model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) + model.learn(5000) + + env = model.get_env() + obs = env.reset() + lstm_states = None + num_envs = 1 + episode_starts = np.ones((num_envs,), dtype=bool) + while True: + action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) + obs, rewards, dones, info = env.step(action) + episode_starts = dones + env.render() diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst index e88f55ed..22ea92d2 100644 --- a/docs/modules/ppo_recurrent.rst +++ b/docs/modules/ppo_recurrent.rst @@ -59,7 +59,8 @@ Example model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1) model.learn(5000) - mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=20, warn=False) + env = model.get_env() + mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False) print(mean_reward) model.save("ppo_recurrent") @@ -67,13 +68,12 @@ Example model = RecurrentPPO.load("ppo_recurrent") - env = model.get_env() obs = env.reset() - states = None + lstm_states = None num_envs = 1 episode_starts = np.ones((num_envs,), dtype=bool) while True: - action, states = model.predict(obs, state=states, episode_start=episode_starts, deterministic=True) + action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) obs, rewards, dones, info = env.step(action) episode_starts = dones env.render() From bd2d5e28aa211788e288df2fd1fa4e67f12e70b1 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Dec 2021 12:25:31 +0100 Subject: [PATCH 30/50] Update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 815956d6..f452de7c 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ See documentation for the full list of included features. 
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) +- [PPO with recurrent policy (RecurrentPPO)](https://arxiv.org/abs/1707.06347) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) From 0f0ce0be8dd6f132a92ef1eee1a1e994f6b0105e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 23 Feb 2022 12:04:55 +0100 Subject: [PATCH 31/50] Use Pendulum-v1 for tests --- tests/test_lstm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index b30d2318..6993aa07 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -80,7 +80,7 @@ def test_policy_kwargs(policy_kwargs): model.learn(total_timesteps=32) -@pytest.mark.parametrize("env", ["Pendulum-v0", "CartPole-v1"]) +@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) def test_run(env): model = RecurrentPPO( "MlpLstmPolicy", @@ -96,7 +96,7 @@ def test_run(env): def test_run_sde(): model = RecurrentPPO( "MlpLstmPolicy", - "Pendulum-v0", + "Pendulum-v1", n_steps=16, seed=0, create_eval_env=True, From 116d0a6eadde1373c5f89bc01ebdef61311deddb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 23 Feb 2022 13:08:15 +0100 Subject: [PATCH 32/50] Fix image env --- tests/test_lstm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 6993aa07..5310abab 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -6,6 +6,7 @@ from gym.wrappers.time_limit import TimeLimit from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.evaluation import evaluate_policy from sb3_contrib import RecurrentPPO @@ -58,7 +59,7 @@ def step(self, action): def test_cnn(): model = RecurrentPPO( "CnnLstmPolicy", - "Breakout-v0", + FakeImageEnv(screen_height=40, screen_width=40, n_channels=3), n_steps=16, seed=0, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), From c32bb74fba2ea6690f0e71d395ddd2162b82db40 Mon Sep 17 00:00:00 2001 From: Neville Walo <43504521+Walon1998@users.noreply.github.com> Date: Tue, 8 Mar 2022 22:31:35 +0100 Subject: [PATCH 33/50] Speedup LSTM forward pass (#63) * added more efficient lstm implementation * Rename and add comment Co-authored-by: Antonin Raffin --- sb3_contrib/common/recurrent/policies.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 2b97fd69..711767d2 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -148,6 +148,13 @@ def _process_sequence( features_sequence = features.reshape((n_envs, -1, lstm.input_size)).swapaxes(0, 1) episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) + # If we don't have to reset the state in the middle of a sequence + # we can avoid the for loop, which speeds up things + if th.all(episode_starts == 0.0): + lstm_output, lstm_states = lstm(features_sequence, lstm_states) + lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states + lstm_output = [] # Iterate over the sequence for features, episode_start in zip_strict(features_sequence, episode_starts): From 
86e0f6fae24b6e1618ecbe9705125528c451ae0d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Apr 2022 22:40:58 +0200 Subject: [PATCH 34/50] Fixes --- sb3_contrib/ppo_recurrent/policies.py | 6 ------ sb3_contrib/ppo_recurrent/ppo_recurrent.py | 9 ++++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sb3_contrib/ppo_recurrent/policies.py b/sb3_contrib/ppo_recurrent/policies.py index ce1fe7d3..d9b37458 100644 --- a/sb3_contrib/ppo_recurrent/policies.py +++ b/sb3_contrib/ppo_recurrent/policies.py @@ -1,5 +1,3 @@ -from stable_baselines3.common.policies import register_policy - from sb3_contrib.common.recurrent.policies import ( RecurrentActorCriticCnnPolicy, RecurrentActorCriticPolicy, @@ -9,7 +7,3 @@ MlpLstmPolicy = RecurrentActorCriticPolicy CnnLstmPolicy = RecurrentActorCriticCnnPolicy MultiInputLstmPolicy = RecurrentMultiInputActorCriticPolicy - -register_policy("MlpLstmPolicy", RecurrentActorCriticPolicy) -register_policy("CnnLstmPolicy", RecurrentActorCriticCnnPolicy) -register_policy("MultiInputLstmPolicy", RecurrentMultiInputActorCriticPolicy) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 2e8ba40c..a650cffb 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -9,7 +9,7 @@ from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -18,6 +18,7 @@ from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy from sb3_contrib.common.recurrent.type_aliases import RNNStates +from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy class RecurrentPPO(OnPolicyAlgorithm): @@ -64,6 +65,12 @@ class RecurrentPPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpLstmPolicy": MlpLstmPolicy, + "CnnLstmPolicy": CnnLstmPolicy, + "MultiInputLstmPolicy": MultiInputLstmPolicy, + } + def __init__( self, policy: Union[str, Type[RecurrentActorCriticPolicy]], From 3fc6e518e6b85cf74fcd59c8a845ed575a3513c7 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Apr 2022 22:57:11 +0200 Subject: [PATCH 35/50] Remove OpenAI sampling and improve coverage --- sb3_contrib/common/recurrent/buffers.py | 90 ++++++---------------- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 6 +- tests/test_lstm.py | 3 +- 3 files changed, 26 insertions(+), 73 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 57063c64..5feff9ff 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -20,7 +20,8 @@ class RecurrentRolloutBuffer(RolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param 
lstm_states: Dummy LSTM states to have the correct shapes when reseting the buffer + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor @@ -37,11 +38,9 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, - sampling_strategy: str = "default", # "default" or "per_env" ): self.lstm_states = lstm_states self.initial_lstm_states = None - self.sampling_strategy = sampling_strategy self.starts, self.ends = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) @@ -64,7 +63,7 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: super().add(*args, **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: - assert self.full, "" + assert self.full, "Rollout buffer must be full before sampling from it" # Prepare the data if not self.generator_ready: @@ -95,62 +94,22 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf # Sampling strategy that allows any mini batch size but requires # more complexity and use of padding - if self.sampling_strategy == "default": - # No shuffling - # indices = np.arange(self.buffer_size * self.n_envs) - # Trick to shuffle a bit: keep the sequence order - # but split the indices in two - split_index = np.random.randint(self.buffer_size * self.n_envs) - indices = np.arange(self.buffer_size * self.n_envs) - indices = np.concatenate((indices[split_index:], indices[:split_index])) - - env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) - # Flag first timestep as change of environment - env_change[0, :] = 1.0 - env_change = self.swap_and_flatten(env_change) - - start_idx = 0 - while start_idx < self.buffer_size * self.n_envs: - batch_inds = indices[start_idx : start_idx + batch_size] - yield self._get_samples(batch_inds, env_change) - start_idx += batch_size - return - - # ==== OpenAI Baselines way of sampling, constraint in the batch size and number of environments ==== - n_minibatches = (self.buffer_size * self.n_envs) // batch_size - - assert ( - self.n_envs % n_minibatches == 0 - ), f"{self.n_envs} not a multiple of {n_minibatches} = {self.buffer_size * self.n_envs} // {batch_size}" - n_envs_per_batch = self.n_envs // n_minibatches - - # Do not shuffle the sequence, only the env indices - env_indices = np.random.permutation(self.n_envs) - flat_indices = np.arange(self.buffer_size * self.n_envs).reshape(self.n_envs, self.buffer_size) - - for start_env_idx in range(0, self.n_envs, n_envs_per_batch): - end_env_idx = start_env_idx + n_envs_per_batch - mini_batch_env_indices = env_indices[start_env_idx:end_env_idx] - batch_inds = flat_indices[mini_batch_env_indices].ravel() - lstm_states_pi = ( - self.initial_lstm_states.pi[0][:, mini_batch_env_indices].clone(), - self.initial_lstm_states.pi[1][:, mini_batch_env_indices].clone(), - ) - lstm_states_vf = ( - self.initial_lstm_states.vf[0][:, mini_batch_env_indices].clone(), - self.initial_lstm_states.vf[1][:, mini_batch_env_indices].clone(), - ) - - yield RecurrentRolloutBufferSamples( - observations=self.to_torch(self.observations[batch_inds]), - actions=self.to_torch(self.actions[batch_inds]), - old_values=self.to_torch(self.values[batch_inds].flatten()), - old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), - 
advantages=self.to_torch(self.advantages[batch_inds].flatten()), - returns=self.to_torch(self.returns[batch_inds].flatten()), - lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self.to_torch(self.episode_starts[batch_inds].flatten()), - ) + # Trick to shuffle a bit: keep the sequence order + # but split the indices in two + split_index = np.random.randint(self.buffer_size * self.n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + indices = np.concatenate((indices[split_index:], indices[:split_index])) + + env_change = np.zeros(self.buffer_size * self.n_envs).reshape(self.buffer_size, self.n_envs) + # Flag first timestep as change of environment + env_change[0, :] = 1.0 + env_change = self.swap_and_flatten(env_change) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds, env_change) + start_idx += batch_size def pad(self, tensor: np.ndarray) -> th.Tensor: seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] @@ -217,7 +176,8 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param lstm_states: Dummy LSTM states to have the correct shapes when reseting the buffer + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor @@ -234,12 +194,10 @@ def __init__( gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, - sampling_strategy: str = "default", # "default" or "per_env" ): self.lstm_states = lstm_states self.initial_lstm_states = None - self.sampling_strategy = sampling_strategy - assert sampling_strategy == "default", "'per_env' strategy not supported with dict obs" + self.starts, self.ends = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self): @@ -261,7 +219,7 @@ def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: super().add(*args, **kwargs) def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: - assert self.full, "" + assert self.full, "Rollout buffer must be full before sampling from it" # Prepare the data if not self.generator_ready: @@ -292,8 +250,6 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou if batch_size is None: batch_size = self.buffer_size * self.n_envs - # No shuffling: - # indices = np.arange(self.buffer_size * self.n_envs) # Trick to shuffle a bit: keep the sequence order # but split the indices in two split_index = np.random.randint(self.buffer_size * self.n_envs) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index a650cffb..f4b1c489 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -9,7 +9,7 @@ from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, 
Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -89,7 +89,6 @@ def __init__( use_sde: bool = False, sde_sample_freq: int = -1, target_kl: Optional[float] = None, - sampling_strategy: str = "default", # "default" or "per_env" tensorboard_log: Optional[str] = None, create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, @@ -113,7 +112,6 @@ def __init__( tensorboard_log=tensorboard_log, create_eval_env=create_eval_env, policy_kwargs=policy_kwargs, - policy_base=ActorCriticPolicy, verbose=verbose, seed=seed, device=device, @@ -132,7 +130,6 @@ def __init__( self.clip_range_vf = clip_range_vf self.target_kl = target_kl self._last_lstm_states = None - self.sampling_strategy = sampling_strategy if _init_setup_model: self._setup_model() @@ -184,7 +181,6 @@ def _setup_model(self) -> None: gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, - sampling_strategy=self.sampling_strategy, ) # Initialize schedules for policy/value clipping diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 5310abab..d8a76f86 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -103,9 +103,10 @@ def test_run_sde(): create_eval_env=True, sde_sample_freq=4, use_sde=True, + clip_range_vf=0.1, ) - model.learn(total_timesteps=32, eval_freq=16) + model.learn(total_timesteps=200, eval_freq=150) def test_dict_obs(): From 88f950402f895ecac06081f45a61ed891dbfdf1d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Apr 2022 23:18:22 +0200 Subject: [PATCH 36/50] Sync with SB3 PPO --- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index f4b1c489..1700d29c 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -47,6 +47,7 @@ class RecurrentPPO(OnPolicyAlgorithm): This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling. 
+ :param normalize_advantage: Whether to normalize or not the advantage :param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping @@ -83,6 +84,7 @@ def __init__( gae_lambda: float = 0.95, clip_range: Union[float, Schedule] = 0.2, clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, @@ -128,6 +130,7 @@ def __init__( self.n_epochs = n_epochs self.clip_range = clip_range self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage self.target_kl = target_kl self._last_lstm_states = None @@ -391,7 +394,8 @@ def train(self) -> None: values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) From 662f218d2b873dc8ffa326bd6b22b0c4b4e0b4bc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Apr 2022 23:47:30 +0200 Subject: [PATCH 37/50] Pass state shape and allow lstm kwargs --- sb3_contrib/common/recurrent/buffers.py | 29 +++++++++-------- sb3_contrib/common/recurrent/policies.py | 37 +++++++++++++++++++--- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 9 +++--- 3 files changed, 51 insertions(+), 24 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 5feff9ff..2d452697 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -20,7 +20,7 @@ class RecurrentRolloutBuffer(RolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param lstm_states: Dummy LSTM states to have the correct shapes when reseting the buffer + :param hidden_state_shape: Shape of the buffer that will collect lstm states :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. 
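A minimal sketch of the shape convention this patch introduces, with made-up sizes (the patch itself builds the shape from n_steps, lstm.num_layers, n_envs and lstm.hidden_size; nothing below is patch code):

import numpy as np

# Illustrative sizes only: 128 steps, 1 LSTM layer, 8 envs, 64 hidden units
n_steps, n_lstm_layers, n_envs, lstm_hidden_size = 128, 1, 8, 64

# hidden_state_shape handed to the buffer: one slot per collected transition
hidden_state_shape = (n_steps, n_lstm_layers, n_envs, lstm_hidden_size)

# reset() can now allocate zeroed storage directly from the shape,
# instead of calling np.zeros_like() on dummy state arrays
hidden_states_pi = np.zeros(hidden_state_shape, dtype=np.float32)
cell_states_pi = np.zeros(hidden_state_shape, dtype=np.float32)
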
@@ -33,23 +33,23 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - lstm_states: Tuple[np.ndarray, np.ndarray], + hidden_state_shape: Tuple[int, int, int, int], device: Union[th.device, str] = "cpu", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - self.lstm_states = lstm_states + self.hidden_state_shape = hidden_state_shape self.initial_lstm_states = None self.starts, self.ends = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): super().reset() - self.hidden_states_pi = np.zeros_like(self.lstm_states[0]) - self.cell_states_pi = np.zeros_like(self.lstm_states[1]) - self.hidden_states_vf = np.zeros_like(self.lstm_states[0]) - self.cell_states_vf = np.zeros_like(self.lstm_states[1]) + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ @@ -129,6 +129,7 @@ def _get_samples( self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) n_layers = self.hidden_states_pi.shape[1] + # Number of sequences n_seq = len(self.starts) max_length = self.pad(self.actions[batch_inds]).shape[0] # TODO: output mask to not backpropagate everywhere @@ -176,7 +177,7 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param lstm_states: Dummy LSTM states to have the correct shapes when reseting the buffer + :param hidden_state_shape: Shape of the buffer that will collect lstm states :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. 
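The other half of this patch (the policies.py hunks further down) forwards an optional lstm_kwargs dict straight to torch.nn.LSTM. A hedged usage sketch with assumed sizes, showing the kind of option this enables; the dropout value is only an example:

import torch as th
from torch import nn

features_dim, lstm_hidden_size, n_lstm_layers = 64, 256, 2
lstm_kwargs = dict(dropout=0.5)  # forwarded verbatim to nn.LSTM

lstm_actor = nn.LSTM(features_dim, lstm_hidden_size, num_layers=n_lstm_layers, **lstm_kwargs)

# batch_first is left at its default, so inputs are (seq_len, n_envs, features_dim)
features_sequence = th.zeros(5, 8, features_dim)
lstm_output, (hidden, cell) = lstm_actor(features_sequence)
print(lstm_output.shape)  # torch.Size([5, 8, 256])
print(hidden.shape)       # torch.Size([2, 8, 256])
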
@@ -189,23 +190,23 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - lstm_states: Tuple[np.ndarray, np.ndarray], + hidden_state_shape: Tuple[int, int, int, int], device: Union[th.device, str] = "cpu", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - self.lstm_states = lstm_states + self.hidden_state_shape = hidden_state_shape self.initial_lstm_states = None self.starts, self.ends = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self): super().reset() - self.hidden_states_pi = np.zeros_like(self.lstm_states[0]) - self.cell_states_pi = np.zeros_like(self.lstm_states[1]) - self.hidden_states_vf = np.zeros_like(self.lstm_states[0]) - self.cell_states_vf = np.zeros_like(self.lstm_states[1]) + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) def add(self, *args, lstm_states: RNNStates, **kwargs) -> None: """ diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 711767d2..2bf1c846 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -56,6 +56,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): :param shared_lstm: Whether the LSTM is shared between the actor and the critic. By default, only the actor has a recurrent network. :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. 
""" def __init__( @@ -81,6 +83,7 @@ def __init__( n_lstm_layers: int = 1, shared_lstm: bool = False, enable_critic_lstm: bool = False, + lstm_kwargs: Optional[Dict[str, Any]] = None, ): self.lstm_output_dim = lstm_hidden_size super().__init__( @@ -103,21 +106,38 @@ def __init__( optimizer_kwargs, ) + self.lstm_kwargs = lstm_kwargs or {} self.shared_lstm = shared_lstm self.enable_critic_lstm = enable_critic_lstm - self.lstm_actor = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) - self.lstm_shape = (n_lstm_layers, 1, lstm_hidden_size) + self.lstm_actor = nn.LSTM( + self.features_dim, + lstm_hidden_size, + num_layers=n_lstm_layers, + **self.lstm_kwargs, + ) + # For the predict() method, to initialize hidden states + # (n_lstm_layers, batch_size, lstm_hidden_size) + self.lstm_hidden_state_shape = (n_lstm_layers, 1, lstm_hidden_size) self.critic = None self.lstm_critic = None assert not ( self.shared_lstm and self.enable_critic_lstm ), "You must choose between shared LSTM, seperate or no LSTM for the critic" + # No LSTM for the critic, we still need to convert + # output of features extractor to the correct size + # (size of the output of the actor lstm) if not (self.shared_lstm or self.enable_critic_lstm): self.critic = nn.Linear(self.features_dim, lstm_hidden_size) + # Use a separate LSTM for the critic if self.enable_critic_lstm: - self.lstm_critic = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) + self.lstm_critic = nn.LSTM( + self.features_dim, + lstm_hidden_size, + num_layers=n_lstm_layers, + **self.lstm_kwargs, + ) # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) @@ -348,7 +368,7 @@ def predict( # state : (n_layers, n_envs, dim) if state is None: # Initialize hidden states to zeros - state = np.concatenate([np.zeros(self.lstm_shape) for _ in range(n_envs)], axis=1) + state = np.concatenate([np.zeros(self.lstm_hidden_state_shape) for _ in range(n_envs)], axis=1) state = (state, state) if episode_start is None: @@ -419,7 +439,8 @@ class RecurrentActorCriticCnnPolicy(RecurrentActorCriticPolicy): :param shared_lstm: Whether the LSTM is shared between the actor and the critic. By default, only the actor has a recurrent network. :param enable_critic_lstm: Use a seperate LSTM for the critic. - + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. """ def __init__( @@ -444,6 +465,7 @@ def __init__( lstm_hidden_size: int = 256, n_lstm_layers: int = 1, enable_critic_lstm: bool = False, + lstm_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( observation_space, @@ -466,6 +488,7 @@ def __init__( lstm_hidden_size, n_lstm_layers, enable_critic_lstm, + lstm_kwargs, ) @@ -506,6 +529,8 @@ class RecurrentMultiInputActorCriticPolicy(RecurrentActorCriticPolicy): :param shared_lstm: Whether the LSTM is shared between the actor and the critic. By default, only the actor has a recurrent network. :param enable_critic_lstm: Use a seperate LSTM for the critic. + :param lstm_kwargs: Additional keyword arguments to pass the the LSTM + constructor. 
""" def __init__( @@ -530,6 +555,7 @@ def __init__( lstm_hidden_size: int = 256, n_lstm_layers: int = 1, enable_critic_lstm: bool = False, + lstm_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( observation_space, @@ -552,4 +578,5 @@ def __init__( lstm_hidden_size, n_lstm_layers, enable_critic_lstm, + lstm_kwargs, ) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 1700d29c..628b6d8b 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -159,11 +159,8 @@ def _setup_model(self) -> None: if not isinstance(self.policy, RecurrentActorCriticPolicy): raise ValueError("Policy must subclass RecurrentActorCriticPolicy") - hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - lstm_states = (np.zeros(hidden_state_shape, dtype=np.float32), np.zeros(hidden_state_shape, dtype=np.float32)) - single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size) - # hidden states for actor and critic + # hidden and cell states for actor and critic self._last_lstm_states = RNNStates( ( th.zeros(single_hidden_state_shape).to(self.device), @@ -175,11 +172,13 @@ def _setup_model(self) -> None: ), ) + hidden_state_buffer_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + self.rollout_buffer = buffer_cls( self.n_steps, self.observation_space, self.action_space, - lstm_states, + hidden_state_buffer_shape, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, From fd068500fcf9c66695c94def1e7d8c6044db40ad Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Apr 2022 23:53:59 +0200 Subject: [PATCH 38/50] Update tests --- tests/test_lstm.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index d8a76f86..8cc5271d 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -68,7 +68,18 @@ def test_cnn(): model.learn(total_timesteps=32) -@pytest.mark.parametrize("policy_kwargs", [{}, dict(shared_lstm=True), dict(enable_critic_lstm=True, lstm_hidden_size=4)]) +@pytest.mark.parametrize( + "policy_kwargs", + [ + {}, + dict(shared_lstm=True), + dict( + enable_critic_lstm=True, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + ), + ], +) def test_policy_kwargs(policy_kwargs): model = RecurrentPPO( "MlpLstmPolicy", From f5e9b34b85e6aa56256aee136050fdae4f7e2350 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 13 Apr 2022 00:28:41 +0200 Subject: [PATCH 39/50] Add masking for padded sequences --- sb3_contrib/common/recurrent/buffers.py | 60 +++++++++++++++----- sb3_contrib/common/recurrent/type_aliases.py | 2 + sb3_contrib/ppo_recurrent/ppo_recurrent.py | 7 ++- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 2d452697..08192d5a 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -111,9 +111,20 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf yield self._get_samples(batch_inds, env_change) start_idx += batch_size - def pad(self, tensor: np.ndarray) -> th.Tensor: + def pad(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] - return th.nn.utils.rnn.pad_sequence(seq) + return th.nn.utils.rnn.pad_sequence(seq, padding_value=padding_value) + + def 
_pad_and_flatten(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. + From (max_length, n_seq, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param tensor: + :return: + """ + return self.pad(tensor, padding_value).swapaxes(0, 1).flatten() def _get_samples( self, @@ -147,15 +158,20 @@ def _get_samples( lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) + # Prime number, unlikely to happen + padding_value = 6739122773 return RecurrentRolloutBufferSamples( + # (max_length, n_seq, obs_dim) to (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) observations=self.pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=self.pad(self.values[batch_inds]).swapaxes(0, 1).flatten(), - old_log_prob=self.pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), - advantages=self.pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), - returns=self.pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), + old_values=self._pad_and_flatten(self.values[batch_inds]), + old_log_prob=self._pad_and_flatten(self.log_probs[batch_inds]), + advantages=self._pad_and_flatten(self.advantages[batch_inds]), + returns=self._pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self.pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), + episode_starts=self._pad_and_flatten(self.episode_starts[batch_inds]), + # Hack to detect padding + mask=self._pad_and_flatten(self.returns[batch_inds], padding_value) != padding_value, ) @@ -268,9 +284,21 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou yield self._get_samples(batch_inds, env_change) start_idx += batch_size - def pad(self, tensor: np.ndarray) -> th.Tensor: + def pad(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: + # Create sequences given start and end seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] - return th.nn.utils.rnn.pad_sequence(seq) + return th.nn.utils.rnn.pad_sequence(seq, padding_value=padding_value) + + def _pad_and_flatten(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. 
+ From (max_length, n_seq, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param tensor: + :return: + """ + return self.pad(tensor, padding_value).swapaxes(0, 1).flatten() def _get_samples( self, @@ -308,13 +336,17 @@ def _get_samples( key: obs.swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items() } + # Prime number, unlikely to happen + padding_value = 6739122773 return RecurrentDictRolloutBufferSamples( observations=observations, actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=self.pad(self.values[batch_inds]).swapaxes(0, 1).flatten(), - old_log_prob=self.pad(self.log_probs[batch_inds]).swapaxes(0, 1).flatten(), - advantages=self.pad(self.advantages[batch_inds]).swapaxes(0, 1).flatten(), - returns=self.pad(self.returns[batch_inds]).swapaxes(0, 1).flatten(), + old_values=self._pad_and_flatten(self.values[batch_inds]), + old_log_prob=self._pad_and_flatten(self.log_probs[batch_inds]), + advantages=self._pad_and_flatten(self.advantages[batch_inds]), + returns=self._pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self.pad(self.episode_starts[batch_inds]).swapaxes(0, 1).flatten(), + episode_starts=self._pad_and_flatten(self.episode_starts[batch_inds]), + # Hack to detect padding + mask=self._pad_and_flatten(self.returns[batch_inds], padding_value) != padding_value, ) diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py index 0bf019ca..1ae9a087 100644 --- a/sb3_contrib/common/recurrent/type_aliases.py +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -18,6 +18,7 @@ class RecurrentRolloutBufferSamples(NamedTuple): returns: th.Tensor lstm_states: RNNStates episode_starts: th.Tensor + mask: th.Tensor class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): @@ -29,3 +30,4 @@ class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): returns: th.Tensor lstm_states: RNNStates episode_starts: th.Tensor + mask: th.Tensor diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 628b6d8b..67911d56 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -402,6 +402,9 @@ def train(self) -> None: # clipped surrogate loss policy_loss_1 = advantages * ratio policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + # Mask padded sequences + policy_loss_1 = policy_loss_1 * rollout_data.mask + policy_loss_2 = policy_loss_2 * rollout_data.mask policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() # Logging @@ -419,7 +422,9 @@ def train(self) -> None: values - rollout_data.old_values, -clip_range_vf, clip_range_vf ) # Value loss using the TD(gae_lambda) target - value_loss = F.mse_loss(rollout_data.returns, values_pred) + # Mask padded sequences + value_loss = th.mean(((rollout_data.returns - values_pred) * rollout_data.mask) ** 2) + value_losses.append(value_loss.item()) # Entropy loss favor exploration From 1cd27da166732a9e6f2f53a1984eb02f3f936c9b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 13 Apr 2022 01:45:53 +0200 Subject: [PATCH 40/50] Update default in perf test --- tests/test_lstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 8cc5271d..d88129e4 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -153,7 
+153,7 @@ def make_env(): n_epochs=10, max_grad_norm=1, gae_lambda=0.98, - policy_kwargs=dict(net_arch=[dict(vf=[64])], ortho_init=False), + policy_kwargs=dict(net_arch=[dict(vf=[64])], ortho_init=False, enable_critic_lstm=True), ) model.learn(total_timesteps=50_000, callback=eval_callback) From c52959b1ab68eca67d72b2835397b296118e754a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 15 Apr 2022 20:25:19 +0200 Subject: [PATCH 41/50] Remove TODO, mask is now working --- sb3_contrib/common/recurrent/buffers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 08192d5a..feaacaf9 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -143,7 +143,6 @@ def _get_samples( # Number of sequences n_seq = len(self.starts) max_length = self.pad(self.actions[batch_inds]).shape[0] - # TODO: output mask to not backpropagate everywhere padded_batch_size = n_seq * max_length lstm_states_pi = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) @@ -316,7 +315,6 @@ def _get_samples( n_layers = self.hidden_states_pi.shape[1] n_seq = len(self.starts) max_length = self.pad(self.actions[batch_inds]).shape[0] - # TODO: output mask to not backpropagate everywhere padded_batch_size = n_seq * max_length lstm_states_pi = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) From 673d23a47ef563894865b392a20596de79ef47d9 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 1 May 2022 20:51:10 +0200 Subject: [PATCH 42/50] Add helper to remove duplicated code, remove hack for padding --- sb3_contrib/common/recurrent/buffers.py | 173 ++++++++++++--------- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 1 - 2 files changed, 100 insertions(+), 74 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index feaacaf9..b8538346 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Generator, Optional, Tuple, Union import numpy as np @@ -13,6 +14,69 @@ ) +def pad( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Chunk sequences and pad them to have constant dimensions. + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param tensor: Tensor of shape (batch_size, *tensor_shape) + :return: + """ + # Create sequences given start and end + seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] + return th.nn.utils.rnn.pad_sequence(seq, padding_value=padding_value) + + +def pad_and_flatten( + seq_start_indices: np.ndarray, + seq_end_indices: np.ndarray, + device: th.device, + tensor: np.ndarray, + padding_value: float = 0.0, +) -> th.Tensor: + """ + Pad and flatten the sequences of scalar values, + while keeping the sequence order. + From (max_length, n_seq, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + + :param seq_start_indices: Indices of the transitions that start a sequence + :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device (cpu, gpu, ...) 
+ :param tensor: + :return: + """ + return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).swapaxes(0, 1).flatten() + + +def create_sequencers( + episode_starts: np.ndarray, + env_change: np.ndarray, + device: th.device, +): + # Create sequence if env changes too + seq_start = np.logical_or(episode_starts, env_change).flatten() + # First index is always the beginning of a sequence + seq_start[0] = True + # Retrieve indices of sequence starts + seq_start_indices = np.where(seq_start == True)[0] # noqa: E712 + # End of sequence are just before sequence starts + # Last index is also always end of a sequence + seq_end_indices = np.concatenate([(seq_start_indices - 1)[1:], np.array([len(episode_starts)])]) + + # Create padding method for this minibatch + # to avoid repeating arguments (seq_start_indices, seq_end_indices) + local_pad = partial(pad, seq_start_indices, seq_end_indices, device) + local_pad_and_flatten = partial(pad_and_flatten, seq_start_indices, seq_end_indices, device) + return seq_start_indices, local_pad, local_pad_and_flatten + + class RecurrentRolloutBuffer(RolloutBuffer): """ Rollout buffer that also stores the invalid action masks associated with each observation. @@ -41,7 +105,7 @@ def __init__( ): self.hidden_state_shape = hidden_state_shape self.initial_lstm_states = None - self.starts, self.ends = None, None + self.seq_start_indices, self.seq_end_indices = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): @@ -111,66 +175,48 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf yield self._get_samples(batch_inds, env_change) start_idx += batch_size - def pad(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: - seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] - return th.nn.utils.rnn.pad_sequence(seq, padding_value=padding_value) - - def _pad_and_flatten(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: - """ - Pad and flatten the sequences of scalar values, - while keeping the sequence order. 
- From (max_length, n_seq, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) - - :param tensor: - :return: - """ - return self.pad(tensor, padding_value).swapaxes(0, 1).flatten() - def _get_samples( self, batch_inds: np.ndarray, env_change: np.ndarray, env: Optional[VecNormalize] = None, ) -> RecurrentRolloutBufferSamples: - # Create sequence if env change too - seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]).flatten() - # First index is always the beginning of a sequence - seq_start[0] = True - self.starts = np.where(seq_start == True)[0] # noqa: E712 - self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) n_layers = self.hidden_states_pi.shape[1] # Number of sequences - n_seq = len(self.starts) + n_seq = len(self.seq_start_indices) max_length = self.pad(self.actions[batch_inds]).shape[0] padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence lstm_states_pi = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 - self.cell_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), ) lstm_states_vf = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 - self.cell_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), ) lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) - # Prime number, unlikely to happen - padding_value = 6739122773 return RecurrentRolloutBufferSamples( # (max_length, n_seq, obs_dim) to (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) observations=self.pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=self._pad_and_flatten(self.values[batch_inds]), - old_log_prob=self._pad_and_flatten(self.log_probs[batch_inds]), - advantages=self._pad_and_flatten(self.advantages[batch_inds]), - returns=self._pad_and_flatten(self.returns[batch_inds]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self._pad_and_flatten(self.episode_starts[batch_inds]), - # Hack to detect padding - mask=self._pad_and_flatten(self.returns[batch_inds], padding_value) != padding_value, + 
episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), ) @@ -213,7 +259,7 @@ def __init__( ): self.hidden_state_shape = hidden_state_shape self.initial_lstm_states = None - self.starts, self.ends = None, None + self.seq_start_indices, self.seq_end_indices = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs) def reset(self): @@ -283,48 +329,32 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou yield self._get_samples(batch_inds, env_change) start_idx += batch_size - def pad(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: - # Create sequences given start and end - seq = [self.to_torch(tensor[start : end + 1]) for start, end in zip(self.starts, self.ends)] - return th.nn.utils.rnn.pad_sequence(seq, padding_value=padding_value) - - def _pad_and_flatten(self, tensor: np.ndarray, padding_value: float = 0.0) -> th.Tensor: - """ - Pad and flatten the sequences of scalar values, - while keeping the sequence order. - From (max_length, n_seq, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) - - :param tensor: - :return: - """ - return self.pad(tensor, padding_value).swapaxes(0, 1).flatten() - def _get_samples( self, batch_inds: np.ndarray, env_change: np.ndarray, env: Optional[VecNormalize] = None, ) -> RecurrentDictRolloutBufferSamples: - # Create sequence if env change too - seq_start = np.logical_or(self.episode_starts[batch_inds], env_change[batch_inds]).flatten() - # First index is always the beginning of a sequence - seq_start[0] = True - self.starts = np.where(seq_start == True)[0] # noqa: E712 - self.ends = np.concatenate([(self.starts - 1)[1:], np.array([len(batch_inds)])]) + # Retrieve sequence starts and utility function + self.seq_start_indices, self.pad, self.pad_and_flatten = create_sequencers( + self.episode_starts[batch_inds], env_change[batch_inds], self.device + ) n_layers = self.hidden_states_pi.shape[1] - n_seq = len(self.starts) + n_seq = len(self.seq_start_indices) max_length = self.pad(self.actions[batch_inds]).shape[0] padded_batch_size = n_seq * max_length + # We retrieve the lstm hidden states that will allow + # to properly initialize the LSTM at the beginning of each sequence lstm_states_pi = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 - self.cell_states_pi[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), ) lstm_states_vf = ( # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 - self.cell_states_vf[batch_inds][seq_start == True].reshape(n_layers, n_seq, -1), # noqa: E712 + self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), ) lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) @@ -334,17 +364,14 @@ def _get_samples( key: obs.swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in 
observations.items() } - # Prime number, unlikely to happen - padding_value = 6739122773 return RecurrentDictRolloutBufferSamples( observations=observations, actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), - old_values=self._pad_and_flatten(self.values[batch_inds]), - old_log_prob=self._pad_and_flatten(self.log_probs[batch_inds]), - advantages=self._pad_and_flatten(self.advantages[batch_inds]), - returns=self._pad_and_flatten(self.returns[batch_inds]), + old_values=self.pad_and_flatten(self.values[batch_inds]), + old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), + advantages=self.pad_and_flatten(self.advantages[batch_inds]), + returns=self.pad_and_flatten(self.returns[batch_inds]), lstm_states=RNNStates(lstm_states_pi, lstm_states_vf), - episode_starts=self._pad_and_flatten(self.episode_starts[batch_inds]), - # Hack to detect padding - mask=self._pad_and_flatten(self.returns[batch_inds], padding_value) != padding_value, + episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), + mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), ) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 67911d56..7d5fc53e 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -13,7 +13,6 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv -from torch.nn import functional as F from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy From e271d03fe660a59c01ca0ad2808551fc2462fefe Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 8 May 2022 15:23:59 +0200 Subject: [PATCH 43/50] Enable LSTM critic and raise threshold for cartpole with no vel --- sb3_contrib/common/recurrent/policies.py | 6 +++--- tests/test_lstm.py | 17 +++++++++++------ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 2bf1c846..7bd27e9e 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -82,7 +82,7 @@ def __init__( lstm_hidden_size: int = 256, n_lstm_layers: int = 1, shared_lstm: bool = False, - enable_critic_lstm: bool = False, + enable_critic_lstm: bool = True, lstm_kwargs: Optional[Dict[str, Any]] = None, ): self.lstm_output_dim = lstm_hidden_size @@ -464,7 +464,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, lstm_hidden_size: int = 256, n_lstm_layers: int = 1, - enable_critic_lstm: bool = False, + enable_critic_lstm: bool = True, lstm_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( @@ -554,7 +554,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, lstm_hidden_size: int = 256, n_lstm_layers: int = 1, - enable_critic_lstm: bool = False, + enable_critic_lstm: bool = True, lstm_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__( diff --git a/tests/test_lstm.py b/tests/test_lstm.py index d88129e4..e1c0e95d 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -8,6 +8,7 @@ from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv from 
stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import VecNormalize from sb3_contrib import RecurrentPPO @@ -134,10 +135,10 @@ def make_env(): env = TimeLimit(env, max_episode_steps=500) return env - env = make_vec_env(make_env, n_envs=8) + env = VecNormalize(make_vec_env(make_env, n_envs=8)) eval_callback = EvalCallback( - make_vec_env(make_env, n_envs=4), + VecNormalize(make_vec_env(make_env, n_envs=4), training=False, norm_reward=False), n_eval_episodes=20, eval_freq=5000 // env.num_envs, ) @@ -149,16 +150,20 @@ def make_env(): learning_rate=0.0007, verbose=1, batch_size=256, - seed=0, + seed=1, n_epochs=10, max_grad_norm=1, gae_lambda=0.98, - policy_kwargs=dict(net_arch=[dict(vf=[64])], ortho_init=False, enable_critic_lstm=True), + policy_kwargs=dict( + net_arch=[dict(vf=[64])], + lstm_hidden_size=64, + ortho_init=False, + enable_critic_lstm=True, + ), ) model.learn(total_timesteps=50_000, callback=eval_callback) # Maximum episode reward is 500. # In CartPole-v1, a non-recurrent policy can easily get >= 450. # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50. - # LSTM policies can reach above 400, but it varies a lot between runs; consistently get >=150. - evaluate_policy(model, env, reward_threshold=160) + evaluate_policy(model, env, reward_threshold=450) From 73bb89cff3ee2480e21e710354ad26d215c36b0a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 8 May 2022 15:35:27 +0200 Subject: [PATCH 44/50] Fix tests --- tests/test_lstm.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index e1c0e95d..8ff81f5a 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -73,12 +73,17 @@ def test_cnn(): "policy_kwargs", [ {}, - dict(shared_lstm=True), + dict(shared_lstm=True, enable_critic_lstm=False), dict( enable_critic_lstm=True, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), ), + dict( + enable_critic_lstm=False, + lstm_hidden_size=4, + lstm_kwargs=dict(dropout=0.5), + ), ], ) def test_policy_kwargs(policy_kwargs): @@ -93,6 +98,18 @@ def test_policy_kwargs(policy_kwargs): model.learn(total_timesteps=32) +def test_check(): + policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True) + with pytest.raises(AssertionError): + model = RecurrentPPO( + "MlpLstmPolicy", + "CartPole-v1", + n_steps=16, + seed=0, + policy_kwargs=policy_kwargs, + ) + + @pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) def test_run(env): model = RecurrentPPO( From efa61814687be03d6c72dafa33ec6747245fbc96 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 18 May 2022 23:39:12 +0200 Subject: [PATCH 45/50] Update doc and tests --- README.md | 2 +- docs/guide/examples.rst | 10 ++++++++ docs/misc/changelog.rst | 4 ++-- docs/modules/ppo_recurrent.rst | 30 ++++++++++++++++++++++-- sb3_contrib/common/recurrent/buffers.py | 25 ++++++++++++++++---- sb3_contrib/common/recurrent/policies.py | 17 ++++++++++++-- sb3_contrib/version.txt | 2 +- tests/test_deterministic.py | 17 +++++++++----- tests/test_lstm.py | 2 +- 9 files changed, 89 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 82510b48..f54ae854 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ See documentation for the full list of included features. 
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) -- [PPO with recurrent policy (RecurrentPPO)](https://arxiv.org/abs/1707.06347) +- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/) - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index b73fe385..cd4851d0 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -77,6 +77,14 @@ RecurrentPPO Train a PPO agent with a recurrent policy on the CartPole environment. + +.. note:: + + It is particularly important to pass the ``lstm_states`` + and ``episode_start`` argument to the ``predict()`` method, + so the cell and hidden states of the LSTM are correctly updated. + + .. code-block:: python import numpy as np @@ -88,8 +96,10 @@ Train a PPO agent with a recurrent policy on the CartPole environment. env = model.get_env() obs = env.reset() + # cell and hidden state of the LSTM lstm_states = None num_envs = 1 + # Episode start signals are used to reset the lstm states episode_starts = np.ones((num_envs,), dtype=bool) while True: action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7d57128f..9725f814 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.5.1a5 (WIP) +Release 1.5.1a6 (WIP) ------------------------------- Breaking Changes: @@ -17,7 +17,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ -- Added ``RecurrentPPO`` +- Added ``RecurrentPPO`` (aka PPO LSTM) Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst index 22ea92d2..4c8bd69a 100644 --- a/docs/modules/ppo_recurrent.rst +++ b/docs/modules/ppo_recurrent.rst @@ -22,8 +22,7 @@ algorithm. Other than adding support for recurrent policies (LSTM here), the beh Notes ----- -.. - Paper: https://arxiv.org/abs/2006.14171 -.. - Blog post: https://costa.sh/blog-a-closer-look-at-invalid-action-masking-in-policy-gradient-algorithms.html +- Blog post: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/ Can I use? @@ -48,6 +47,12 @@ Dict ❌ ✔️ Example ------- +.. note:: + + It is particularly important to pass the ``lstm_states`` + and ``episode_start`` argument to the ``predict()`` method, + so the cell and hidden states of the LSTM are correctly updated. + .. 
code-block:: python @@ -69,8 +74,10 @@ Example model = RecurrentPPO.load("ppo_recurrent") obs = env.reset() + # cell and hidden state of the LSTM lstm_states = None num_envs = 1 + # Episode start signals are used to reset the lstm states episode_starts = np.ones((num_envs,), dtype=bool) while True: action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True) @@ -83,6 +90,16 @@ Example Results ------- +Report on environments with masked velocity (with and without framestack) can be found here: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4 + +``RecurrentPPO`` was evaluated against PPO on: + +- PendulumNoVel-v1 +- LunarLanderNoVel-v2 +- CartPoleNoVel-v1 +- MountainCarContinuousNoVel-v0 +- CarRacing-v0 + How to replicate the results? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -91,8 +108,17 @@ Clone the repo for the experiment: .. code-block:: bash git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo git checkout feat/recurrent-ppo + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo ppo_lstm --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + Parameters ---------- diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index b8538346..fe1ad6a1 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Generator, Optional, Tuple, Union +from typing import Callable, Generator, Optional, Tuple, Union import numpy as np import torch as th @@ -26,8 +26,11 @@ def pad( :param seq_start_indices: Indices of the transitions that start a sequence :param seq_end_indices: Indices of the transitions that end a sequence + :param device: PyTorch device :param tensor: Tensor of shape (batch_size, *tensor_shape) - :return: + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) + :return: (n_seq * max_length, *tensor_shape) """ # Create sequences given start and end seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] @@ -49,7 +52,9 @@ def pad_and_flatten( :param seq_start_indices: Indices of the transitions that start a sequence :param seq_end_indices: Indices of the transitions that end a sequence :param device: PyTorch device (cpu, gpu, ...) - :param tensor: + :param tensor: Tensor of shape (max_length, n_seq, 1) + :param padding_value: Value used to pad sequence to the same length + (zero padding by default) :return: """ return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).swapaxes(0, 1).flatten() @@ -59,7 +64,17 @@ def create_sequencers( episode_starts: np.ndarray, env_change: np.ndarray, device: th.device, -): +) -> Tuple[np.ndarray, Callable, Callable]: + """ + Create the utility function to chunk data into + sequences and pad them to create fixed size tensors. 
+ + :param episode_starts: Indices where an episode starts + :param env_change: Indices where the data collected + come from a different env (when using multiple env for data collection) + :param device: PyTorch device + :return: + """ # Create sequence if env changes too seq_start = np.logical_or(episode_starts, env_change).flatten() # First index is always the beginning of a sequence @@ -223,7 +238,7 @@ def _get_samples( class RecurrentDictRolloutBuffer(DictRolloutBuffer): """ Dict Rollout buffer used in on-policy algorithms like A2C/PPO. - Extends the RolloutBuffer to use dictionary observations + Extends the RecurrentRolloutBuffer to use dictionary observations It corresponds to ``buffer_size`` transitions collected using the current policy. diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 7bd27e9e..3ce577c7 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -53,8 +53,9 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy): excluding the learning rate, to pass to the optimizer :param lstm_hidden_size: Number of hidden units for each LSTM layer. :param n_lstm_layers: Number of LSTM layers. - :param shared_lstm: Whether the LSTM is shared between the actor and the critic. - By default, only the actor has a recurrent network. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic + (in that case, only the actor gradient is used) + By default, the actor and the critic have two separate LSTM. :param enable_critic_lstm: Use a seperate LSTM for the critic. :param lstm_kwargs: Additional keyword arguments to pass the the LSTM constructor. @@ -161,6 +162,16 @@ def _process_sequence( episode_starts: th.Tensor, lstm: nn.LSTM, ) -> Tuple[th.Tensor, th.Tensor]: + """ + Do a forward pass in the LSTM network. + + :param features: Input tensor + :param lstm_states: previous cell and hidden states of the LSTM + :param episode_starts: Indicates when a new episode starts, + in that case, we need to reset LSTM states. + :param lstm: LSTM object. + :return: LSTM output and updated LSTM states. 
+ """ # LSTM logic # (sequence length, n_envs, features dim) (batch size = n envs) n_envs = lstm_states[0].shape[1] @@ -181,6 +192,7 @@ def _process_sequence( hidden, lstm_states = lstm( features.unsqueeze(dim=0), ( + # Reset the states at the beginning of a new episode (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0], (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1], ), @@ -218,6 +230,7 @@ def forward( latent_vf = latent_pi.detach() lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) else: + # Critic only has a feedforward network latent_vf = self.critic(features) lstm_states_vf = lstm_states_pi diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index bccb8c67..1e5deca7 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.5.1a5 +1.5.1a6 diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index 1ba7283e..458d3f06 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -3,7 +3,7 @@ from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import VecNormalize -from sb3_contrib import ARS, QRDQN, TQC +from sb3_contrib import ARS, QRDQN, TQC, RecurrentPPO from sb3_contrib.common.vec_env import AsyncEval N_STEPS_TRAINING = 500 @@ -11,7 +11,7 @@ ARS_MULTI = "ars_multi" -@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI]) +@pytest.mark.parametrize("algo", [ARS, QRDQN, TQC, ARS_MULTI, RecurrentPPO]) def test_deterministic_training_common(algo): results = [[], []] rewards = [[], []] @@ -32,9 +32,12 @@ def test_deterministic_training_common(algo): kwargs.update({"learning_starts": 100, "target_update_interval": 100}) elif algo == ARS: kwargs.update({"n_delta": 2}) - + elif algo == RecurrentPPO: + kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)}) + kwargs.update({"n_steps": 50, "n_epochs": 4}) + policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy" for i in range(2): - model = algo("MlpPolicy", env_id, seed=SEED, **kwargs) + model = algo(policy_str, env_id, seed=SEED, **kwargs) learn_kwargs = {"total_timesteps": N_STEPS_TRAINING} if ars_multi: @@ -46,9 +49,11 @@ def test_deterministic_training_common(algo): model.learn(**learn_kwargs) env = model.get_env() obs = env.reset() + states = None + episode_start = None for _ in range(100): - action, _ = model.predict(obs, deterministic=False) - obs, reward, _, _ = env.step(action) + action, states = model.predict(obs, state=states, episode_start=episode_start, deterministic=False) + obs, reward, episode_start, _ = env.step(action) results[i].append(action) rewards[i].append(reward) assert sum(results[0]) == sum(results[1]), results diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 8ff81f5a..f0ba3e6a 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -101,7 +101,7 @@ def test_policy_kwargs(policy_kwargs): def test_check(): policy_kwargs = dict(shared_lstm=True, enable_critic_lstm=True) with pytest.raises(AssertionError): - model = RecurrentPPO( + RecurrentPPO( "MlpLstmPolicy", "CartPole-v1", n_steps=16, From 564d42808c5a1d55a32fb7836fadc0c1bbbd1d1c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 18 May 2022 23:40:45 +0200 Subject: [PATCH 46/50] Doc fix --- docs/modules/ppo_recurrent.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/ppo_recurrent.rst b/docs/modules/ppo_recurrent.rst index 4c8bd69a..bc819c39 100644 --- a/docs/modules/ppo_recurrent.rst +++ 
b/docs/modules/ppo_recurrent.rst @@ -1,4 +1,4 @@ -.. _ppo_mask: +.. _ppo_lstm: .. automodule:: sb3_contrib.ppo_recurrent From 408ed247f497c7c9d7e0e44c12fa2f2540dfa5ad Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 19:58:11 +0200 Subject: [PATCH 47/50] Fix for new Sphinx version --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 45216198..85c9c16e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,7 +100,7 @@ def __getattr__(cls, name): # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. From 6acb64aaeba988e65201d9a4b0a0e223ed4967a3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 29 May 2022 22:07:28 +0200 Subject: [PATCH 48/50] Fix doc note --- docs/guide/algos.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 48b0314a..234e6f8c 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -18,7 +18,8 @@ TRPO ✔️ ✔️ ✔️ ✔️ .. note:: - Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm. + ``Tuple`` observation spaces are not supported by any environment, + however, single-level ``Dict`` spaces are Actions ``gym.spaces``: From 5fd8be768a84a818caefe1c1978ba4867e954cc3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 30 May 2022 02:58:49 +0200 Subject: [PATCH 49/50] Switch to batch first, no more additional swap --- docs/misc/changelog.rst | 4 +++- sb3_contrib/common/recurrent/buffers.py | 24 +++++++++++------------- sb3_contrib/version.txt | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 848a93d4..5554ded5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,11 @@ Changelog ========== -Release 1.5.1a7 (WIP) +Release 1.5.1a8 (WIP) ------------------------------- +**Add RecurrentPPO (aka PPO LSTM)** + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Upgraded to Stable-Baselines3 >= 1.5.1a7 diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index fe1ad6a1..f5b473ac 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -30,11 +30,11 @@ def pad( :param tensor: Tensor of shape (batch_size, *tensor_shape) :param padding_value: Value used to pad sequence to the same length (zero padding by default) - :return: (n_seq * max_length, *tensor_shape) + :return: (n_seq, max_length, *tensor_shape) """ # Create sequences given start and end seq = [th.tensor(tensor[start : end + 1], device=device) for start, end in zip(seq_start_indices, seq_end_indices)] - return th.nn.utils.rnn.pad_sequence(seq, padding_value=padding_value) + return th.nn.utils.rnn.pad_sequence(seq, batch_first=True, padding_value=padding_value) def pad_and_flatten( @@ -47,7 +47,7 @@ def pad_and_flatten( """ Pad and flatten the sequences of scalar values, while keeping the sequence order. 
- From (max_length, n_seq, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) + From (batch_size, 1) to (n_seq, max_length, 1) -> (n_seq * max_length,) :param seq_start_indices: Indices of the transitions that start a sequence :param seq_end_indices: Indices of the transitions that end a sequence @@ -57,7 +57,7 @@ def pad_and_flatten( (zero padding by default) :return: """ - return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).swapaxes(0, 1).flatten() + return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() def create_sequencers( @@ -204,7 +204,7 @@ def _get_samples( n_layers = self.hidden_states_pi.shape[1] # Number of sequences n_seq = len(self.seq_start_indices) - max_length = self.pad(self.actions[batch_inds]).shape[0] + max_length = self.pad(self.actions[batch_inds]).shape[1] padded_batch_size = n_seq * max_length # We retrieve the lstm hidden states that will allow # to properly initialize the LSTM at the beginning of each sequence @@ -222,9 +222,9 @@ def _get_samples( lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) return RecurrentRolloutBufferSamples( - # (max_length, n_seq, obs_dim) to (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) - observations=self.pad(self.observations[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape), - actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), + # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) + observations=self.pad(self.observations[batch_inds]).reshape((padded_batch_size,) + self.obs_shape), + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), old_values=self.pad_and_flatten(self.values[batch_inds]), old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), advantages=self.pad_and_flatten(self.advantages[batch_inds]), @@ -357,7 +357,7 @@ def _get_samples( n_layers = self.hidden_states_pi.shape[1] n_seq = len(self.seq_start_indices) - max_length = self.pad(self.actions[batch_inds]).shape[0] + max_length = self.pad(self.actions[batch_inds]).shape[1] padded_batch_size = n_seq * max_length # We retrieve the lstm hidden states that will allow # to properly initialize the LSTM at the beginning of each sequence @@ -375,13 +375,11 @@ def _get_samples( lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} - observations = { - key: obs.swapaxes(0, 1).reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items() - } + observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} return RecurrentDictRolloutBufferSamples( observations=observations, - actions=self.pad(self.actions[batch_inds]).swapaxes(0, 1).reshape((padded_batch_size,) + self.actions.shape[1:]), + actions=self.pad(self.actions[batch_inds]).reshape((padded_batch_size,) + self.actions.shape[1:]), old_values=self.pad_and_flatten(self.values[batch_inds]), old_log_prob=self.pad_and_flatten(self.log_probs[batch_inds]), advantages=self.pad_and_flatten(self.advantages[batch_inds]), diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index e39732bd..511e75b2 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.5.1a7 +1.5.1a8 From 
7a1d3e8e6d79823dc6cf74dd8cfc295c50a252bb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 30 May 2022 04:05:25 +0200 Subject: [PATCH 50/50] Add comments and mask entropy loss --- sb3_contrib/common/recurrent/buffers.py | 24 ++++++++-------------- sb3_contrib/common/recurrent/policies.py | 18 ++++++++++------ sb3_contrib/ppo_recurrent/ppo_recurrent.py | 14 ++++++------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index f5b473ac..88ff4254 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -55,7 +55,7 @@ def pad_and_flatten( :param tensor: Tensor of shape (max_length, n_seq, 1) :param padding_value: Value used to pad sequence to the same length (zero padding by default) - :return: + :return: (n_seq * max_length,) aka (padded_batch_size,) """ return pad(seq_start_indices, seq_end_indices, device, tensor, padding_value).flatten() @@ -73,7 +73,9 @@ def create_sequencers( :param env_change: Indices where the data collected come from a different env (when using multiple env for data collection) :param device: PyTorch device - :return: + :return: Indices of the transitions that start a sequence, + pad and pad_and_flatten utilities tailored for this batch + (sequence starts and ends indices are fixed) """ # Create sequence if env changes too seq_start = np.logical_or(episode_starts, env_change).flatten() @@ -94,12 +96,13 @@ def create_sequencers( class RecurrentRolloutBuffer(RolloutBuffer): """ - Rollout buffer that also stores the invalid action masks associated with each observation. + Rollout buffer that also stores the LSTM cell and hidden states. :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space :param hidden_state_shape: Shape of the buffer that will collect lstm states + (n_steps, lstm.num_layers, n_envs, lstm.hidden_size) :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. @@ -119,7 +122,6 @@ def __init__( n_envs: int = 1, ): self.hidden_state_shape = hidden_state_shape - self.initial_lstm_states = None self.seq_start_indices, self.seq_end_indices = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) @@ -151,6 +153,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) for tensor in [ "observations", "actions", @@ -240,16 +245,6 @@ class RecurrentDictRolloutBuffer(DictRolloutBuffer): Dict Rollout buffer used in on-policy algorithms like A2C/PPO. Extends the RecurrentRolloutBuffer to use dictionary observations - It corresponds to ``buffer_size`` transitions collected - using the current policy. - This experience will be discarded after the policy update. - In order to use PPO objective, we also store the current value of each state - and the log probability of each taken action. 
-
-    The term rollout here refers to the model-free notion and should not
-    be used with the concept of rollout used in model-based RL or planning.
-    Hence, it is only involved in policy and value function training but not action selection.
-
     :param buffer_size: Max number of elements in the buffer
     :param observation_space: Observation space
     :param action_space: Action space
@@ -273,7 +268,6 @@ def __init__(
         n_envs: int = 1,
     ):
         self.hidden_state_shape = hidden_state_shape
-        self.initial_lstm_states = None
         self.seq_start_indices, self.seq_end_indices = None, None
         super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs=n_envs)
 
diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py
index 3ce577c7..16f1c200 100644
--- a/sb3_contrib/common/recurrent/policies.py
+++ b/sb3_contrib/common/recurrent/policies.py
@@ -23,6 +23,8 @@ class RecurrentActorCriticPolicy(ActorCriticPolicy):
     """
     Recurrent policy class for actor-critic algorithms (has both policy and value prediction).
     To be used with A2C, PPO and the likes.
+    It assumes that the actor and the critic LSTMs
+    have the same architecture.
 
     :param observation_space: Observation space
     :param action_space: Action space
@@ -173,11 +175,14 @@ def _process_sequence(
         :return: LSTM output and updated LSTM states.
         """
         # LSTM logic
-        # (sequence length, n_envs, features dim) (batch size = n envs)
-        n_envs = lstm_states[0].shape[1]
+        # (sequence length, batch size, features dim)
+        # (batch size = n_envs for data collection or n_seq when doing gradient update)
+        n_seq = lstm_states[0].shape[1]
         # Batch to sequence
-        features_sequence = features.reshape((n_envs, -1, lstm.input_size)).swapaxes(0, 1)
-        episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1)
+        # (padded batch size, features_dim) -> (n_seq, max length, features_dim) -> (max length, n_seq, features_dim)
+        # note: max length (max sequence length) is always 1 during data collection
+        features_sequence = features.reshape((n_seq, -1, lstm.input_size)).swapaxes(0, 1)
+        episode_starts = episode_starts.reshape((n_seq, -1)).swapaxes(0, 1)
 
         # If we don't have to reset the state in the middle of a sequence
         # we can avoid the for loop, which speeds up things
@@ -193,12 +198,13 @@ def _process_sequence(
                 features.unsqueeze(dim=0),
                 (
                     # Reset the states at the beginning of a new episode
-                    (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0],
-                    (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1],
+                    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0],
+                    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1],
                 ),
             )
             lstm_output += [hidden]
         # Sequence to batch
+        # (sequence length, n_seq, lstm_out_dim) -> (batch_size, lstm_out_dim)
         lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1)
         return lstm_output, lstm_states
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index 7d5fc53e..f0920f9d 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -153,6 +153,8 @@ def _setup_model(self) -> None:
         )
         self.policy = self.policy.to(self.device)
 
+        # We assume that the LSTMs for the actor and the critic
+        # have the same architecture
         lstm = self.policy.lstm_actor
 
         if not isinstance(self.policy, RecurrentActorCriticPolicy):
@@ -188,7 +190,7 @@ def _setup_model(self) -> None:
         self.clip_range = get_schedule_fn(self.clip_range)
         if self.clip_range_vf is not None:
             if isinstance(self.clip_range_vf, (float, int)):
-                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
+                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"
 
             self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
 
@@ -201,7 +203,7 @@ def _setup_learn(
         n_eval_episodes: int = 5,
         log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-        tb_log_name: str = "run",
+        tb_log_name: str = "RecurrentPPO",
     ) -> Tuple[int, BaseCallback]:
         """
         Initialize different variables needed for training.
@@ -251,7 +253,7 @@ def collect_rollouts(
         """
         assert isinstance(
             rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer)
-        ), "RolloutBuffer doesn't support recurrent policy"
+        ), f"{rollout_buffer} doesn't support recurrent policy"
         assert self._last_obs is not None, "No previous observation was provided"
 
         # Switch to eval mode (this affects batch norm / dropout)
@@ -265,7 +267,6 @@ def collect_rollouts(
 
         callback.on_rollout_start()
 
-        rollout_buffer.initial_lstm_states = deepcopy(self._last_lstm_states)
         lstm_states = deepcopy(self._last_lstm_states)
 
         while n_steps < n_rollout_steps:
@@ -366,7 +367,6 @@ def train(self) -> None:
         clip_fractions = []
 
         continue_training = True
-        # self.policy.features_extractor.debug = True
 
         # train for n_epochs epochs
         for epoch in range(self.n_epochs):
@@ -429,9 +429,9 @@ def train(self) -> None:
                 # Entropy loss favors exploration
                 if entropy is None:
                     # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-log_prob)
+                    entropy_loss = -th.mean(-(log_prob * rollout_data.mask))
                 else:
-                    entropy_loss = -th.mean(entropy)
+                    entropy_loss = -th.mean(entropy * rollout_data.mask)
 
                 entropy_losses.append(entropy_loss.item())
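
For context, the snippet below is a minimal sketch of how the RecurrentPPO API introduced by this patch series is driven from user code. It mirrors the updated loop in tests/test_deterministic.py, where predict() receives the previous LSTM states via state= and a per-env episode_start flag; the environment id, timestep budget, and hyperparameters are illustrative placeholders, not values taken from the patches:

import numpy as np

from sb3_contrib import RecurrentPPO

# Illustrative settings only; the actual defaults come from the RecurrentPPO constructor.
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", n_steps=128, verbose=1)
model.learn(total_timesteps=5_000)

vec_env = model.get_env()
obs = vec_env.reset()
# The caller threads the LSTM states through predict();
# episode_starts tells the policy when to reset them (see _process_sequence above).
lstm_states = None
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    episode_starts = dones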