def create_sequence_slicer(
    episode_start_indices: np.ndarray, device: Union[th.device, str]
) -> Callable[[np.ndarray, List[int]], th.Tensor]:
    """
    Build a function that cuts whole sequences out of a flat array and pads them into one batch.

    :param episode_start_indices: Start offset of every sequence in the flat array,
        followed by one extra trailing entry holding the end offset of the last sequence
        (sequence ``i`` spans ``episode_start_indices[i]:episode_start_indices[i + 1]``).
    :param device: PyTorch device on which the returned tensors are created.
    :return: A function ``(array, seq_indices) -> padded tensor``.
    """

    def create_sequence_minibatch(tensor: np.ndarray, seq_indices: List[int]) -> th.Tensor:
        """
        Create minibatch of whole sequences.

        :param tensor: Flat array that will be sliced (e.g. observations, rewards)
        :param seq_indices: Integer indices of the sequences to be used.
        :return: (max_sequence_length, batch_size=n_seq, features_size), shorter sequences
            are padded at the end with zeros (``pad_sequence`` default).
        :raises ValueError: If no sequence index is given.
        """
        # Materialize once: the indices may come from any iterable (e.g. a sampler).
        selected = list(seq_indices)
        if not selected:
            # pad_sequence fails with an obscure error on an empty list; fail early and clearly.
            raise ValueError("seq_indices must contain at least one sequence index")

        sequences = [
            th.tensor(
                tensor[episode_start_indices[i] : episode_start_indices[i + 1]],
                device=device,
            )
            for i in selected
        ]
        return pad_sequence(sequences)

    return create_sequence_minibatch
# =============================================================================
# sb3_contrib/common/recurrent/buffers.py -- whole-sequence rollout buffers
# =============================================================================


def _make_sequence_batch_sampler(n_sequences: int, batch_size: Optional[int]) -> BatchSampler:
    """
    Build a sampler yielding random, disjoint groups of sequence indices.

    :param n_sequences: Number of whole sequences available in the buffer.
    :param batch_size: Number of sequences per minibatch; ``None`` means a single
        batch holding every sequence (``BatchSampler`` itself rejects ``None``).
    :return: Sampler over lists of sequence indices.
    """
    if batch_size is None:
        batch_size = n_sequences
    # The last (possibly smaller) batch is kept so that at least one batch is
    # always produced, even when there are fewer sequences than ``batch_size``.
    return BatchSampler(SubsetRandomSampler(range(n_sequences)), batch_size, drop_last=False)


class RecurrentSequenceRolloutBuffer(RecurrentRolloutBuffer):
    """
    Sequence Rollout buffer used in on-policy algorithms like A2C/PPO.
    Overrides the RecurrentRolloutBuffer to yield 3d batches of whole sequences

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: Tuple[int, int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(
            buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
        )

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSequenceSamples, None, None]:
        """
        Yield minibatches made of whole sequences.

        Every field of a sample has shape (max_sequence_length, n_seq, features_size);
        sequences shorter than the longest one of the batch are zero-padded at the end
        and ``mask`` is 1 on real steps and 0 on padded steps.

        :param batch_size: Number of sequences per minibatch. ``None`` returns all
            sequences in a single batch.
        :return: Generator over minibatches of whole sequences.
        """
        assert self.full, "Rollout buffer must be full before sampling from it"
        # Prepare the data
        if not self.generator_ready:
            # Force a sequence boundary at the first step of every env so that, once
            # flattened, data coming from different envs is never merged into one sequence.
            self.episode_starts[0, :] = 1
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])

            self.episode_start_indices = np.where(self.episode_starts == 1)[0]
            self.generator_ready = True

        batch_sampler = _make_sequence_batch_sampler(len(self.episode_start_indices), batch_size)
        # add a dummy index (end of the data) to make the slicing below simpler
        sequence_bounds = np.concatenate([self.episode_start_indices, np.array([len(self.episode_starts)])])
        create_minibatch = create_sequence_slicer(sequence_bounds, self.device)

        # Slicing an array of ones exactly like the data gives 1 on real steps and
        # 0 on the padding. (Calling ones_like on an already padded batch would mark
        # the padding as valid.)
        valid_steps = np.ones_like(self.returns)

        for indices in batch_sampler:
            yield RecurrentRolloutBufferSequenceSamples(
                observations=create_minibatch(self.observations, indices),
                actions=create_minibatch(self.actions, indices),
                old_values=create_minibatch(self.values, indices),
                old_log_prob=create_minibatch(self.log_probs, indices),
                advantages=create_minibatch(self.advantages, indices),
                returns=create_minibatch(self.returns, indices),
                mask=create_minibatch(valid_steps, indices),
            )


class RecurrentSequenceDictRolloutBuffer(RecurrentDictRolloutBuffer):
    """
    Sequence Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Overrides the DictRecurrentRolloutBuffer to yield 3d batches of whole sequences

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        hidden_state_shape: Tuple[int, int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(
            buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
        )

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSequenceSamples, None, None]:
        """
        Yield minibatches made of whole sequences (dict observations).

        Every tensor of a sample has shape (max_sequence_length, n_seq, features_size);
        sequences shorter than the longest one of the batch are zero-padded at the end
        and ``mask`` is 1 on real steps and 0 on padded steps.

        :param batch_size: Number of sequences per minibatch. ``None`` returns all
            sequences in a single batch.
        :return: Generator over minibatches of whole sequences.
        """
        assert self.full, "Rollout buffer must be full before sampling from it"
        # Prepare the data
        if not self.generator_ready:
            # Force a sequence boundary at the first step of every env so that, once
            # flattened, data coming from different envs is never merged into one sequence.
            self.episode_starts[0, :] = 1
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            for tensor in [
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])

            self.episode_start_indices = np.where(self.episode_starts == 1)[0]
            self.generator_ready = True

        # Same sampling policy as RecurrentSequenceRolloutBuffer: the last batch is kept,
        # otherwise nothing at all is yielded when there are fewer sequences than batch_size.
        batch_sampler = _make_sequence_batch_sampler(len(self.episode_start_indices), batch_size)
        # add a dummy index (end of the data) to make the slicing below simpler
        sequence_bounds = np.concatenate([self.episode_start_indices, np.array([len(self.episode_starts)])])
        create_minibatch = create_sequence_slicer(sequence_bounds, self.device)

        # 1 on real steps, 0 on padding (see RecurrentSequenceRolloutBuffer.get).
        valid_steps = np.ones_like(self.returns)

        for indices in batch_sampler:
            obs_batch = {key: create_minibatch(obs, indices) for key, obs in self.observations.items()}

            yield RecurrentDictRolloutBufferSequenceSamples(
                observations=obs_batch,
                actions=create_minibatch(self.actions, indices),
                old_values=create_minibatch(self.values, indices),
                old_log_prob=create_minibatch(self.log_probs, indices),
                advantages=create_minibatch(self.advantages, indices),
                returns=create_minibatch(self.returns, indices),
                mask=create_minibatch(valid_steps, indices),
            )


# =============================================================================
# sb3_contrib/common/recurrent/policies.py -- method of the recurrent policy
# class (added right after ``evaluate_actions``)
# =============================================================================


def evaluate_actions_whole_sequence(
    self,
    obs: th.Tensor,
    actions: th.Tensor,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Evaluate actions of batches of whole sequences according to the current policy,
    given the observations.

    No LSTM state is passed in: each sequence of the batch is processed from the
    LSTM's default initial state.

    :param obs: Padded observations, (max_sequence_length, n_seq, ...) as yielded by
        the sequence rollout buffers (a dict of such tensors for dict observations).
    :param actions: Padded actions, (max_sequence_length, n_seq, action_dim).
    :return: estimated value, log likelihood of taking those actions
        and entropy of the action distribution, each (max_sequence_length, n_seq, 1).
    """
    if self.features_extractor_class == FlattenExtractor:
        # The default extractor would flatten everything after the first dim and
        # thereby merge the sequence/batch structure: bypass it, obs is already flat.
        features = obs
    else:
        # Other extractors expect a single leading batch dim: fold (time, batch)
        # into one dim, extract the features, then restore the sequence layout.
        # NOTE(review): assumes extract_features returns a single tensor here -- confirm.
        reference = next(iter(obs.values())) if isinstance(obs, dict) else obs
        seq_len, n_seq = reference.shape[:2]

        def fold(tensor: th.Tensor) -> th.Tensor:
            return tensor.reshape(seq_len * n_seq, *tensor.shape[2:])

        folded_obs = {key: fold(value) for key, value in obs.items()} if isinstance(obs, dict) else fold(obs)
        features = self.extract_features(folded_obs)
        features = features.reshape(seq_len, n_seq, *features.shape[1:])

    latent_pi, _ = self.lstm_actor(features)

    if self.lstm_critic is not None:
        latent_vf, _ = self.lstm_critic(features)
    elif self.shared_lstm:
        # Critic reuses the actor LSTM output without back-propagating into it
        latent_vf = latent_pi.detach()
    else:
        latent_vf = self.critic(features)

    latent_pi = self.mlp_extractor.forward_actor(latent_pi)
    latent_vf = self.mlp_extractor.forward_critic(latent_vf)

    values = self.value_net(latent_vf)

    distribution = self._get_action_dist_from_latent(latent_pi)
    # The underlying torch distribution is used directly so that only the action
    # dimension is reduced and the (time, batch) layout is preserved.
    torch_distribution = distribution.distribution
    if isinstance(torch_distribution, th.distributions.Categorical):
        # Categorical has no per-action-dim axis: its values must have shape (time, batch),
        # and its log_prob/entropy are already one number per step (nothing to sum).
        log_prob = torch_distribution.log_prob(actions.reshape(latent_pi.shape[:-1]))
        entropy = torch_distribution.entropy()
    else:
        log_prob = torch_distribution.log_prob(actions).sum(dim=-1)
        entropy = torch_distribution.entropy().sum(dim=-1)

    # Trailing singleton dim so that the outputs line up with ``values`` and the mask
    return values, log_prob.unsqueeze(-1), entropy.unsqueeze(-1)


# =============================================================================
# sb3_contrib/common/recurrent/type_aliases.py -- samples of whole sequences
# =============================================================================


class RecurrentRolloutBufferSequenceSamples(NamedTuple):
    # Minibatch of whole, padded sequences yielded by RecurrentSequenceRolloutBuffer.get().
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    # Marks which entries are real steps as opposed to padding
    mask: th.Tensor


class RecurrentDictRolloutBufferSequenceSamples(NamedTuple):
    # Minibatch of whole, padded sequences yielded by RecurrentSequenceDictRolloutBuffer.get().
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
    # Marks which entries are real steps as opposed to padding
    mask: th.Tensor
a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 39fd9416..cfd542f4 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -14,7 +14,12 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv -from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer +from sb3_contrib.common.recurrent.buffers import ( + RecurrentDictRolloutBuffer, + RecurrentRolloutBuffer, + RecurrentSequenceDictRolloutBuffer, + RecurrentSequenceRolloutBuffer, +) from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy from sb3_contrib.common.recurrent.type_aliases import RNNStates from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy @@ -80,6 +85,7 @@ def __init__( learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 128, batch_size: Optional[int] = 128, + whole_sequences: bool = False, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, @@ -128,6 +134,7 @@ def __init__( ) self.batch_size = batch_size + self.whole_sequences = whole_sequences self.n_epochs = n_epochs self.clip_range = clip_range self.clip_range_vf = clip_range_vf @@ -142,7 +149,17 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - buffer_cls = RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer + # 3d batches of whole sequences or 2d batches of constant size + if self.whole_sequences: + buffer_cls = ( + RecurrentSequenceDictRolloutBuffer + if isinstance(self.observation_space, spaces.Dict) + else RecurrentSequenceRolloutBuffer + ) + else: + buffer_cls = ( + RecurrentDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentRolloutBuffer + ) self.policy = self.policy_class( 
self.observation_space, @@ -215,7 +232,13 @@ def collect_rollouts( collected, False if callback terminated rollout prematurely. """ assert isinstance( - rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) + rollout_buffer, + ( + RecurrentRolloutBuffer, + RecurrentDictRolloutBuffer, + RecurrentSequenceRolloutBuffer, + RecurrentSequenceDictRolloutBuffer, + ), ), f"{rollout_buffer} doesn't support recurrent policy" assert self._last_obs is not None, "No previous observation was provided" @@ -339,7 +362,9 @@ def train(self) -> None: actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long - actions = rollout_data.actions.long().flatten() + actions = rollout_data.actions.long() + if not self.whole_sequences: + actions = actions.flatten() # Convert mask from float to bool mask = rollout_data.mask > 1e-8 @@ -348,17 +373,24 @@ def train(self) -> None: if self.use_sde: self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions( - rollout_data.observations, - actions, - rollout_data.lstm_states, - rollout_data.episode_starts, - ) + if self.whole_sequences: + values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence( + rollout_data.observations, + actions, + ) + else: + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + rollout_data.lstm_states, + rollout_data.episode_starts, + ) + values = values.flatten() - values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - if self.normalize_advantage: + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if self.normalize_advantage and len(advantages) > 1: advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) # ratio between old and new policy, should be one at the first iteration diff --git a/whole_sequence_speed_test.py b/whole_sequence_speed_test.py new 
"""
Speed/learning-curve comparison between the standard RecurrentPPO batching and
whole-sequence batching (``whole_sequences=True``).

Results are written as tensorboard logs to ``temp/``.
"""
import gymnasium as gym
import numpy as np
import torch.nn as nn
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor, VecNormalize


class MaskVelocityWrapper(gym.ObservationWrapper):
    """
    Gym environment observation wrapper used to mask velocity terms in
    observations. The intention is the make the MDP partially observable.
    Adapted from https://github.com/LiuWenlin595/FinalProject.

    :param env: Gym environment
    """

    # Supported envs: indices of the velocity terms in the observation vector
    velocity_indices = {
        "CartPole-v1": np.array([1, 3]),
        "MountainCar-v0": np.array([1]),
        "MountainCarContinuous-v0": np.array([1]),
        "Pendulum-v1": np.array([2]),
        "LunarLander-v2": np.array([2, 3, 5]),
        "LunarLanderContinuous-v2": np.array([2, 3, 5]),
    }

    def __init__(self, env: gym.Env):
        super().__init__(env)

        env_id: str = env.unwrapped.spec.id
        try:
            masked_indices = self.velocity_indices[env_id]
        except KeyError as err:
            raise NotImplementedError(f"Velocity masking not implemented for {env_id}") from err

        # By default no masking, then zero out the velocity terms
        self.mask = np.ones_like(env.observation_space.sample())
        self.mask[masked_indices] = 0.0

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Return the observation with its velocity terms set to zero."""
        return observation * self.mask


def make_env(mask_vel=False, **kwargs):
    """
    Create an env factory suitable for ``SubprocVecEnv``.

    :param mask_vel: Whether to hide the velocity terms of the observation.
    :param kwargs: Keyword arguments forwarded to ``gym.make`` (e.g. ``id``).
    :return: Function without arguments that builds the environment.
    """

    def _init():
        env = gym.make(**kwargs)
        if mask_vel:
            env = MaskVelocityWrapper(env)
        return env

    return _init


def get_vectorized_envs(n_cpus, **kwargs):
    """
    Create ``n_cpus`` monitored and normalized environments, one subprocess each.

    :param n_cpus: Number of parallel environments.
    :param kwargs: Keyword arguments forwarded to ``make_env``.
    :return: The vectorized environment.
    """
    envs_no_log = SubprocVecEnv([make_env(**kwargs) for _ in range(n_cpus)])
    envs = VecNormalize(VecMonitor(envs_no_log))
    return envs


def _policy_kwargs():
    """Return a fresh copy of the policy settings shared by every run."""
    return {
        "ortho_init": False,
        "activation_fn": nn.ReLU,
        "lstm_hidden_size": 64,
        "enable_critic_lstm": True,
        "net_arch": [dict(pi=[64], vf=[64])],
    }


def run_experiment(env_kwargs, n_cpus, total_timesteps, tb_log_name, **ppo_kwargs):
    """
    Train one RecurrentPPO model and log it to tensorboard.

    Every run gets its own environments so that the normalization statistics of
    one run do not leak into the next one, and the worker processes are always
    shut down afterwards.

    :param env_kwargs: Keyword arguments for ``get_vectorized_envs`` (``id``, ``mask_vel``...).
    :param n_cpus: Number of parallel environments.
    :param total_timesteps: Number of environment steps to train for.
    :param tb_log_name: Name of the tensorboard run.
    :param ppo_kwargs: Remaining ``RecurrentPPO`` hyperparameters.
    """
    envs = get_vectorized_envs(n_cpus=n_cpus, **env_kwargs)
    try:
        model = RecurrentPPO(
            "MlpLstmPolicy",
            envs,
            policy_kwargs=_policy_kwargs(),
            tensorboard_log="temp/",
            verbose=1,
            **ppo_kwargs,
        )
        model.learn(total_timesteps, tb_log_name=tb_log_name)
    finally:
        envs.close()


def main():
    """Run the standard vs. whole-sequence comparison on both benchmark tasks."""
    ############################################################
    # BipedalWalker-v3
    ############################################################
    walker_env = {"id": "BipedalWalker-v3"}
    walker_params = dict(
        n_steps=256,
        gae_lambda=0.95,
        gamma=0.999,
        n_epochs=10,
        ent_coef=0.0,
        learning_rate=0.0003,
        clip_range=0.18,
    )
    run_experiment(
        walker_env,
        n_cpus=32,
        total_timesteps=5_000_000,
        tb_log_name="BipedalWalker-v3_sb3_standard",
        batch_size=256,
        whole_sequences=False,
        **walker_params,
    )
    run_experiment(
        walker_env,
        n_cpus=32,
        total_timesteps=5_000_000,
        tb_log_name="BipedalWalker-v3_whole_sequences",
        batch_size=4,
        whole_sequences=True,  # This sets use of whole sequence batching
        **walker_params,
    )

    ############################################################
    # PendulumNoVel
    ############################################################
    pendulum_env = {"id": "Pendulum-v1", "mask_vel": True}
    pendulum_params = dict(
        n_steps=1024,
        gae_lambda=0.95,
        gamma=0.9,
        n_epochs=10,
        ent_coef=0.0,
        learning_rate=0.001,
        clip_range=0.2,
    )
    # Standard batching: batch_size is deliberately left at the RecurrentPPO default
    run_experiment(
        pendulum_env,
        n_cpus=4,
        total_timesteps=200_000,
        tb_log_name="PendulumNoVel-v1_sb3_standard",
        whole_sequences=False,
        **pendulum_params,
    )
    for batch_size in [2, 4, 8]:
        run_experiment(
            pendulum_env,
            n_cpus=4,
            total_timesteps=200_000,
            tb_log_name=f"PendulumNoVel-v1_whole_sequences_batch_size{batch_size}",
            batch_size=batch_size,
            whole_sequences=True,  # This sets use of whole sequence batching
            **pendulum_params,
        )


if __name__ == "__main__":
    main()