From 6467c79c75b194c863b57c7bdb0614836e61ff41 Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Sat, 12 Nov 2022 22:05:04 +0100 Subject: [PATCH 01/17] added whole sequence batching functionality to PPORecurrent --- sb3_contrib/common/recurrent/buffers.py | 88 +++++++++++ sb3_contrib/common/recurrent/policies.py | 37 +++++ sb3_contrib/common/recurrent/type_aliases.py | 9 ++ sb3_contrib/ppo_recurrent/ppo_recurrent.py | 152 ++++++++++++++++++- 4 files changed, 281 insertions(+), 5 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 53856681..5c81840f 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -3,6 +3,8 @@ import numpy as np import torch as th +from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler +from torch.nn.utils.rnn import pad_sequence from gym import spaces from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.vec_env import VecNormalize @@ -11,6 +13,7 @@ RecurrentDictRolloutBufferSamples, RecurrentRolloutBufferSamples, RNNStates, + RecurrentDictRolloutBufferSequenceSamples, ) @@ -382,3 +385,88 @@ def _get_samples( episode_starts=self.pad_and_flatten(self.episode_starts[batch_inds]), mask=self.pad_and_flatten(np.ones_like(self.returns[batch_inds])), ) + + +class RecurrentSequenceDictRolloutBuffer(RecurrentDictRolloutBuffer): + """ + Sequence Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Overrides the DictRecurrentRolloutBuffer to yield batches of whole sequences + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
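As an aside before the remaining constructor parameters, here is a minimal standalone sketch (separate from the buffer code in this patch) of what a "batch of whole sequences" means in practice: variable-length episodes are right-padded into a single 3d tensor by `torch.nn.utils.rnn.pad_sequence`, and a mask records which steps are real. All tensor shapes and names below are illustrative only.

```python
import torch as th
from torch.nn.utils.rnn import pad_sequence

# Three episodes of different lengths, each with 4-dimensional observations.
episodes = [th.randn(5, 4), th.randn(3, 4), th.randn(7, 4)]

# batch_first=True gives shape (batch_size, max_seq_len, obs_dim) = (3, 7, 4);
# shorter episodes are right-padded with zeros.
obs_batch = pad_sequence(episodes, batch_first=True)
print(obs_batch.shape)  # torch.Size([3, 7, 4])

# A mask with the same layout marks which steps are real (1) and which are padding (0),
# so later loss computations can ignore the padded tail of short episodes.
mask = pad_sequence([th.ones(len(ep)) for ep in episodes], batch_first=True)
print(mask)
```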
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + for tensor in [ + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + + self.episode_start_indices = np.where(self.episode_starts == 1)[0] + self.generator_ready = True + + random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) + + # yields batches of whole sequences, shape: (batch_size, sequence_length, data_length) + for indices in batch_sampler: + obs_batch = {} + for key in self.observations: + obs_batch[key] = pad_sequence([th.Tensor(self.observations[key][self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + + actions_batch = pad_sequence([th.Tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + old_values_batch = pad_sequence([th.Tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + old_log_probs_batch = pad_sequence([th.Tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + advantages_batch = pad_sequence([th.Tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + returns_batch = pad_sequence([th.Tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + + yield RecurrentDictRolloutBufferSequenceSamples( + observations=obs_batch, + actions=actions_batch, + old_values=old_values_batch, + old_log_prob=old_log_probs_batch, + advantages=advantages_batch, + returns=returns_batch + ) \ No newline at end of file diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 1ba52734..690ec8fb 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -336,6 +336,43 @@ def evaluate_actions( values = self.value_net(latent_vf) 
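Stepping back from the diff for a moment: the `get()` method added above shuffles whole episodes, not individual transitions, by feeding episode indices through PyTorch's samplers. A toy sketch of that index bookkeeping follows (numbers are made up, and the bookkeeping is simplified relative to the buffer code, which also pads and stacks the per-episode arrays).

```python
import numpy as np
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

# Flattened rollout of 12 steps across all envs; 1 marks the first step of an episode.
episode_starts = np.array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0])
episode_start_indices = np.where(episode_starts == 1)[0]  # array([0, 4, 7])

# Randomly group episode indices into mini-batches of whole episodes.
sampler = BatchSampler(
    SubsetRandomSampler(range(len(episode_start_indices))), batch_size=2, drop_last=False
)

# Episode i spans [start[i], start[i + 1]); appending the rollout length gives the
# last episode an end index too.
bounds = np.concatenate([episode_start_indices, [len(episode_starts)]])
for indices in sampler:
    for i in indices:
        print(f"episode {i}: steps {bounds[i]}..{bounds[i + 1] - 1}")
```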
return values, log_prob, distribution.entropy() + def evaluate_actions_whole_sequence( + self, + obs: th.Tensor, + actions: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions of batches of whole sequences according to the current policy, + given the observations. + + :param obs: Observation. + :param actions: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, _ = self.lstm_actor(features) + + if self.lstm_critic is not None: + latent_vf, _ = self.lstm_critic(features) + elif self.shared_lstm: + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.distribution.log_prob(actions).sum(dim=-1) + log_prob = log_prob.reshape(log_prob.shape + (1,)) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + def _predict( self, observation: th.Tensor, diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py index 1ae9a087..19e94211 100644 --- a/sb3_contrib/common/recurrent/type_aliases.py +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -31,3 +31,12 @@ class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): lstm_states: RNNStates episode_starts: th.Tensor mask: th.Tensor + + +class RecurrentDictRolloutBufferSequenceSamples(NamedTuple): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor \ No newline at end of file diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 965e0080..8f40ea9d 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -15,7 +15,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv -from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer +from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, RecurrentSequenceDictRolloutBuffer from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy from sb3_contrib.common.recurrent.type_aliases import RNNStates from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy @@ -82,6 +82,7 @@ def __init__( learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 128, batch_size: Optional[int] = 128, + whole_sequences: bool = False, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, @@ -130,6 +131,7 @@ def __init__( ) self.batch_size = batch_size + self.whole_sequences = whole_sequences self.n_epochs = n_epochs self.clip_range = clip_range self.clip_range_vf = clip_range_vf @@ -144,9 +146,14 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - buffer_cls = ( - RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) 
else RecurrentRolloutBuffer - ) + if self.whole_sequences: # TODO + buffer_cls = ( + RecurrentSequenceDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer + ) + else: + buffer_cls = ( + RecurrentDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer + ) self.policy = self.policy_class( self.observation_space, @@ -219,7 +226,7 @@ def collect_rollouts( collected, False if callback terminated rollout prematurely. """ assert isinstance( - rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer) + rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer, RecurrentSequenceDictRolloutBuffer) ), f"{rollout_buffer} doesn't support recurrent policy" assert self._last_obs is not None, "No previous observation was provided" @@ -315,10 +322,145 @@ def collect_rollouts( return True + def train_whole_sequences(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Convert mask from float to bool + # mask = rollout_data.mask > 1e-8 + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence( + rollout_data.observations, + actions, + ) + + values = values.flatten() + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)) + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the different between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + # Mask padded sequences + value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)) + + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no 
analytical form + entropy_loss = -th.mean(-log_prob) + else: + entropy_loss = -th.mean(entropy) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + def train(self) -> None: """ Update policy using the currently gathered rollout buffer. 
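An aside on the update above: the two quantities at its heart are the clipped surrogate objective and the `(exp(log_ratio) - 1) - log_ratio` estimator used for KL-based early stopping. A tiny standalone computation with toy numbers, independent of the policy and buffer classes:

```python
import torch as th

# Toy log-probabilities of 6 transitions under the old and the current policy.
old_log_prob = th.tensor([-1.2, -0.8, -1.5, -0.9, -1.1, -1.0])
log_prob = th.tensor([-1.0, -0.9, -1.4, -1.2, -1.0, -1.1])
advantages = th.tensor([0.5, -0.3, 1.2, 0.1, -0.7, 0.4])
clip_range = 0.2

ratio = th.exp(log_prob - old_log_prob)  # roughly 1 on the first epoch
policy_loss = -th.min(
    advantages * ratio,
    advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range),
).mean()

# k3 estimator of the reverse KL used for early stopping (http://joschu.net/blog/kl-approx.html)
log_ratio = log_prob - old_log_prob
approx_kl = ((th.exp(log_ratio) - 1) - log_ratio).mean()
print(policy_loss.item(), approx_kl.item())
```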
""" + if self.whole_sequences == True: + self.train_whole_sequences() + return + # Switch to train mode (this affects batch norm / dropout) self.policy.set_training_mode(True) # Update optimizer learning rate From ff8cb9d0e2ddbcbba626c55be96b130ebe9a47fe Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Mon, 14 Nov 2022 22:03:33 +0100 Subject: [PATCH 02/17] added masking and fixed some bugs --- sb3_contrib/common/recurrent/buffers.py | 26 +++++++++++--------- sb3_contrib/common/recurrent/policies.py | 9 +++++-- sb3_contrib/common/recurrent/type_aliases.py | 3 ++- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 18 +++++++------- 4 files changed, 32 insertions(+), 24 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 5c81840f..8467f590 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -454,19 +454,21 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou for indices in batch_sampler: obs_batch = {} for key in self.observations: - obs_batch[key] = pad_sequence([th.Tensor(self.observations[key][self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + obs_batch[key] = pad_sequence([th.tensor(self.observations[key][self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - actions_batch = pad_sequence([th.Tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) - old_values_batch = pad_sequence([th.Tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) - old_log_probs_batch = pad_sequence([th.Tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) - advantages_batch = pad_sequence([th.Tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) - returns_batch = pad_sequence([th.Tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]]) for i in indices], batch_first=True) + actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + masks_batch = pad_sequence([th.ones_like(r) for r in returns_batch], batch_first=True) yield RecurrentDictRolloutBufferSequenceSamples( - observations=obs_batch, - actions=actions_batch, - old_values=old_values_batch, - old_log_prob=old_log_probs_batch, - advantages=advantages_batch, - returns=returns_batch + observations={key:th.swapaxes(obs_batch[key], 0, 1) for key in obs_batch}, + actions=th.swapaxes(actions_batch, 0, 
1), + old_values=th.swapaxes(old_values_batch, 0, 1), + old_log_prob=th.swapaxes(old_log_probs_batch, 0, 1), + advantages=th.swapaxes(advantages_batch, 0, 1), + returns=th.swapaxes(returns_batch, 0, 1), + masks=th.swapaxes(masks_batch, 0, 1) ) \ No newline at end of file diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 690ec8fb..99bdaffa 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -367,11 +367,16 @@ def evaluate_actions_whole_sequence( latent_pi = self.mlp_extractor.forward_actor(latent_pi) latent_vf = self.mlp_extractor.forward_critic(latent_vf) + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) log_prob = distribution.distribution.log_prob(actions).sum(dim=-1) log_prob = log_prob.reshape(log_prob.shape + (1,)) - values = self.value_net(latent_vf) - return values, log_prob, distribution.entropy() + + entropy = distribution.distribution.entropy().sum(dim=-1) + entropy = entropy.reshape(entropy.shape + (1,)) + + return values, log_prob, entropy def _predict( self, diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py index 19e94211..ac505c34 100644 --- a/sb3_contrib/common/recurrent/type_aliases.py +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -39,4 +39,5 @@ class RecurrentDictRolloutBufferSequenceSamples(NamedTuple): old_values: th.Tensor old_log_prob: th.Tensor advantages: th.Tensor - returns: th.Tensor \ No newline at end of file + returns: th.Tensor + masks: th.Tensor \ No newline at end of file diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index 8f40ea9d..be1548a5 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -353,7 +353,7 @@ def train_whole_sequences(self) -> None: actions = rollout_data.actions.long().flatten() # Convert mask from float to bool - # mask = rollout_data.mask > 1e-8 + mask = rollout_data.masks > 1e-8 # Re-sample the noise matrix because the log_std has changed if self.use_sde: @@ -364,11 +364,11 @@ def train_whole_sequences(self) -> None: actions, ) - values = values.flatten() + # values = values.flatten() # ?? 
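An aside on the `.sum(dim=-1)` added to the log-probability and entropy above: for a diagonal Gaussian policy, the per-action-dimension terms are summed into one joint value per timestep so they line up with the per-step values and returns. A standalone sketch with `torch.distributions` (all shapes chosen arbitrarily):

```python
import torch as th
from torch.distributions import Normal

seq_len, batch_size, action_dim = 7, 3, 2
dist = Normal(
    loc=th.zeros(seq_len, batch_size, action_dim),
    scale=th.ones(seq_len, batch_size, action_dim),
)
actions = dist.sample()

# Per-dimension log densities have shape (seq_len, batch, action_dim); summing the last
# axis gives one joint log-probability (and entropy) per timestep of each sequence.
log_prob = dist.log_prob(actions).sum(dim=-1)
entropy = dist.entropy().sum(dim=-1)
print(log_prob.shape, entropy.shape)  # torch.Size([7, 3]) torch.Size([7, 3])
```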
# Normalize advantage advantages = rollout_data.advantages if self.normalize_advantage: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) @@ -376,11 +376,11 @@ def train_whole_sequences(self) -> None: # clipped surrogate loss policy_loss_1 = advantages * ratio policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) - policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)) + policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask]) # Logging pg_losses.append(policy_loss.item()) - clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: @@ -394,16 +394,16 @@ def train_whole_sequences(self) -> None: ) # Value loss using the TD(gae_lambda) target # Mask padded sequences - value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)) + value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask]) value_losses.append(value_loss.item()) # Entropy loss favor exploration if entropy is None: # Approximate entropy when no analytical form - entropy_loss = -th.mean(-log_prob) + entropy_loss = -th.mean(-log_prob[mask]) else: - entropy_loss = -th.mean(entropy) + entropy_loss = -th.mean(entropy[mask]) entropy_losses.append(entropy_loss.item()) @@ -415,7 +415,7 @@ def train_whole_sequences(self) -> None: # and Schulman blog: http://joschu.net/blog/kl-approx.html with th.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)).cpu().numpy() + approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy() approx_kl_divs.append(approx_kl_div) if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: From cfc0f7082f25db3cb0e3c19dfcea934e1a398dfa Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Wed, 23 Nov 2022 15:31:27 +0100 Subject: [PATCH 03/17] implemented for non dict obs + fixed bugs + added basic script showing speed improvement --- sb3_contrib/common/recurrent/buffers.py | 89 ++++++++++++++++++-- sb3_contrib/common/recurrent/policies.py | 5 +- sb3_contrib/common/recurrent/type_aliases.py | 10 +++ sb3_contrib/ppo_recurrent/ppo_recurrent.py | 14 +-- whole_sequence_speed_test.py | 55 ++++++++++++ 5 files changed, 157 insertions(+), 16 deletions(-) create mode 100644 whole_sequence_speed_test.py diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 8467f590..71ab29de 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -13,6 +13,7 @@ RecurrentDictRolloutBufferSamples, RecurrentRolloutBufferSamples, RNNStates, + RecurrentRolloutBufferSequenceSamples, RecurrentDictRolloutBufferSequenceSamples, ) @@ -387,10 +388,87 @@ def _get_samples( ) +class RecurrentSequenceRolloutBuffer(RecurrentRolloutBuffer): + """ + Sequence Rollout buffer used in on-policy algorithms like A2C/PPO. 
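Before the non-dict buffer described below, a quick standalone illustration of the masked reductions used in the losses above: means and advantage statistics are taken over real steps only, so the zero padding of short episodes does not bias the update. Toy values only, not the real rollout data.

```python
import torch as th

# (batch, seq_len): two episodes padded to length 5; the first has only 3 real steps.
advantages = th.tensor([[0.8, -0.4, 1.1, 0.0, 0.0],
                        [0.3, -0.9, 0.6, 1.4, -0.2]])
masks = th.tensor([[1.0, 1.0, 1.0, 0.0, 0.0],
                   [1.0, 1.0, 1.0, 1.0, 1.0]])
mask = masks > 1e-8  # float mask -> boolean index, as in the training loop above

valid = advantages[mask]  # 1d tensor holding only the 8 real steps
normalized = (advantages - valid.mean()) / (valid.std() + 1e-8)

# Including the padded zeros would drag the mean toward zero and shrink the std.
print(valid.mean().item(), advantages.mean().item())
```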
+ Overrides the RecurrentRolloutBuffer to yield 3d batches of whole sequences + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + hidden_state_shape: Tuple[int, int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = hidden_state_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSequenceSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + + self.episode_start_indices = np.where(self.episode_starts == 1)[0] + self.generator_ready = True + + random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) + + # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length)) + for indices in batch_sampler: + obs_batch = pad_sequence([th.tensor(self.observations[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + masks_batch = pad_sequence([th.ones_like(r) for r in returns_batch], batch_first=True) + + yield RecurrentRolloutBufferSequenceSamples( + observations=th.swapaxes(obs_batch, 0, 1), + actions=th.swapaxes(actions_batch, 0, 1), + old_values=th.swapaxes(old_values_batch, 0, 1), + old_log_prob=th.swapaxes(old_log_probs_batch, 0, 1), + advantages=th.swapaxes(advantages_batch, 0, 1), + returns=th.swapaxes(returns_batch, 0, 1), + 
masks=th.swapaxes(masks_batch, 0, 1) + ) + + class RecurrentSequenceDictRolloutBuffer(RecurrentDictRolloutBuffer): """ Sequence Dict Rollout buffer used in on-policy algorithms like A2C/PPO. - Overrides the DictRecurrentRolloutBuffer to yield batches of whole sequences + Overrides the DictRecurrentRolloutBuffer to yield 3d batches of whole sequences :param buffer_size: Max number of element in the buffer :param observation_space: Observation space @@ -418,15 +496,10 @@ def __init__( self.seq_start_indices, self.seq_end_indices = None, None super().__init__(buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs) - def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSamples, None, None]: + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSequenceSamples, None, None]: assert self.full, "Rollout buffer must be full before sampling from it" # Prepare the data if not self.generator_ready: - # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) - # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) - for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: - self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) - for key, obs in self.observations.items(): self.observations[key] = self.swap_and_flatten(obs) @@ -450,7 +523,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) - # yields batches of whole sequences, shape: (batch_size, sequence_length, data_length) + # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length) for indices in batch_sampler: obs_batch = {} for key in self.observations: diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 99bdaffa..18ea4c88 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -354,7 +354,10 @@ def evaluate_actions_whole_sequence( and entropy of the action distribution. 
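An aside on why evaluating whole padded sequences is cheap: a single `nn.LSTM` call processes every timestep of every sequence at once, starting from zero hidden states, so no per-step Python loop and no stored states are needed. A self-contained sketch with arbitrary layer sizes:

```python
import torch as th
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8)  # expects (seq_len, batch, input_size) by default

# A padded batch of 3 sequences, maximum length 7, feature dimension 4.
padded = th.randn(7, 3, 4)

# One call runs the recurrence over all timesteps; the initial states default to zeros,
# matching the start of an episode, and an output is produced for every step.
outputs, (h_n, c_n) = lstm(padded)
print(outputs.shape)  # torch.Size([7, 3, 8]) -> a latent for every step of every sequence
```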
""" # Preprocess the observation if needed - features = self.extract_features(obs) + if self.features_extractor_class == FlattenExtractor: # temporary fix to disable the flattening that stable_baselines3 feature extractors do by default + features = obs + else: + features = self.extract_features(obs) latent_pi, _ = self.lstm_actor(features) if self.lstm_critic is not None: diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py index ac505c34..16dbd23b 100644 --- a/sb3_contrib/common/recurrent/type_aliases.py +++ b/sb3_contrib/common/recurrent/type_aliases.py @@ -33,6 +33,16 @@ class RecurrentDictRolloutBufferSamples(RecurrentRolloutBufferSamples): mask: th.Tensor +class RecurrentRolloutBufferSequenceSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + masks: th.Tensor + + class RecurrentDictRolloutBufferSequenceSamples(NamedTuple): observations: TensorDict actions: th.Tensor diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index be1548a5..7fe86e46 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -15,7 +15,7 @@ from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv -from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, RecurrentSequenceDictRolloutBuffer +from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, RecurrentSequenceDictRolloutBuffer, RecurrentSequenceRolloutBuffer from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy from sb3_contrib.common.recurrent.type_aliases import RNNStates from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy @@ -146,9 +146,10 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - if self.whole_sequences: # TODO + # 3d batches of whole sequences or 2d batches of constant size + if self.whole_sequences: buffer_cls = ( - RecurrentSequenceDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentRolloutBuffer + RecurrentSequenceDictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RecurrentSequenceRolloutBuffer ) else: buffer_cls = ( @@ -226,7 +227,7 @@ def collect_rollouts( collected, False if callback terminated rollout prematurely. """ assert isinstance( - rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer, RecurrentSequenceDictRolloutBuffer) + rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer, RecurrentSequenceRolloutBuffer, RecurrentSequenceDictRolloutBuffer) ), f"{rollout_buffer} doesn't support recurrent policy" assert self._last_obs is not None, "No previous observation was provided" @@ -324,7 +325,7 @@ def collect_rollouts( def train_whole_sequences(self) -> None: """ - Update policy using the currently gathered rollout buffer. + Update policy using the currently gathered rollout buffer but do it on 3d batches of whole sequences. 
""" # Switch to train mode (this affects batch norm / dropout) self.policy.set_training_mode(True) @@ -350,7 +351,7 @@ def train_whole_sequences(self) -> None: actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long - actions = rollout_data.actions.long().flatten() + actions = rollout_data.actions.long() # Convert mask from float to bool mask = rollout_data.masks > 1e-8 @@ -364,7 +365,6 @@ def train_whole_sequences(self) -> None: actions, ) - # values = values.flatten() # ?? # Normalize advantage advantages = rollout_data.advantages if self.normalize_advantage: diff --git a/whole_sequence_speed_test.py b/whole_sequence_speed_test.py new file mode 100644 index 00000000..3e86149f --- /dev/null +++ b/whole_sequence_speed_test.py @@ -0,0 +1,55 @@ +from sb3_contrib import RecurrentPPO +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor + +import gym + + +def make_env(**kwargs): + def _init(): + env = gym.make(**kwargs) + return env + return _init + +def get_vectorized_envs(n_cpus, **kwargs): + envs_no_log = SubprocVecEnv([make_env(**kwargs) for _ in range(n_cpus)]) + envs = VecMonitor(envs_no_log) + return envs + + +if __name__ == "__main__": + envs = get_vectorized_envs(n_cpus=32, id="BipedalWalker-v3") + + model = RecurrentPPO( + "MlpLstmPolicy", + envs, + n_steps=2048, + batch_size=64, + gae_lambda=0.95, + gamma=0.999, + whole_sequences=False, + n_epochs=10, + ent_coef=0.0, + learning_rate=0.0003, + clip_range=0.18, + tensorboard_log="temp/", + verbose=1 + ) + model.learn(500000, tb_log_name="speed test sb3-contrib default 2d batching") + + + model = RecurrentPPO( + "MlpLstmPolicy", + envs, + n_steps=2048, + batch_size=4, + gae_lambda=0.95, + gamma=0.999, + whole_sequences=True, + n_epochs=10, + ent_coef=0.0, + learning_rate=0.0003, + clip_range=0.18, + tensorboard_log="temp/", + verbose=1 + ) + model.learn(500000, tb_log_name="speed test sb3-contrib whole sequence 3d batching") \ No newline at end of file From 8b609544e984393f9f7fd83bfd4a9afa895ed1ec Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Mon, 28 Nov 2022 16:00:40 +0100 Subject: [PATCH 04/17] bug fix episode starts after first update --- .gitignore | 2 ++ sb3_contrib/common/recurrent/buffers.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 9f548898..2edfeeb4 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,5 @@ src *.prof MUJOCO_LOG.TXT + +temp/ diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 71ab29de..c9513fe1 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -423,6 +423,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf assert self.full, "Rollout buffer must be full before sampling from it" # Prepare the data if not self.generator_ready: + self.episode_starts[0, :] = 1 for tensor in [ "observations", "actions", @@ -500,6 +501,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou assert self.full, "Rollout buffer must be full before sampling from it" # Prepare the data if not self.generator_ready: + self.episode_starts[0, :] = 1 for key, obs in self.observations.items(): self.observations[key] = self.swap_and_flatten(obs) From 59f4a7ad9d3946d150b9ecb1ceabb93a4783dc38 Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Mon, 28 Nov 2022 16:01:31 +0100 Subject: [PATCH 05/17] updated testing script --- whole_sequence_speed_test.py | 134 
++++++++++++++++++++++++++++++++--- 1 file changed, 123 insertions(+), 11 deletions(-) diff --git a/whole_sequence_speed_test.py b/whole_sequence_speed_test.py index 3e86149f..4778c084 100644 --- a/whole_sequence_speed_test.py +++ b/whole_sequence_speed_test.py @@ -1,29 +1,70 @@ +import gym +import numpy as np +import torch.nn as nn from sb3_contrib import RecurrentPPO -from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor +from stable_baselines3.common.vec_env import SubprocVecEnv, VecMonitor, VecNormalize -import gym +class MaskVelocityWrapper(gym.ObservationWrapper): + """ + Gym environment observation wrapper used to mask velocity terms in + observations. The intention is the make the MDP partially observable. + Adapted from https://github.com/LiuWenlin595/FinalProject. + :param env: Gym environment + """ + + # Supported envs + velocity_indices = { + "CartPole-v1": np.array([1, 3]), + "MountainCar-v0": np.array([1]), + "MountainCarContinuous-v0": np.array([1]), + "Pendulum-v1": np.array([2]), + "LunarLander-v2": np.array([2, 3, 5]), + "LunarLanderContinuous-v2": np.array([2, 3, 5]), + } + + def __init__(self, env: gym.Env): + super().__init__(env) + + env_id: str = env.unwrapped.spec.id + # By default no masking + self.mask = np.ones_like((env.observation_space.sample())) + try: + # Mask velocity + self.mask[self.velocity_indices[env_id]] = 0.0 + except KeyError: + raise NotImplementedError(f"Velocity masking not implemented for {env_id}") + + def observation(self, observation: np.ndarray) -> np.ndarray: + return observation * self.mask -def make_env(**kwargs): +def make_env(mask_vel=False, **kwargs): def _init(): env = gym.make(**kwargs) + if mask_vel: + env = MaskVelocityWrapper(env) return env return _init + def get_vectorized_envs(n_cpus, **kwargs): envs_no_log = SubprocVecEnv([make_env(**kwargs) for _ in range(n_cpus)]) - envs = VecMonitor(envs_no_log) + envs = VecNormalize(VecMonitor(envs_no_log)) return envs if __name__ == "__main__": - envs = get_vectorized_envs(n_cpus=32, id="BipedalWalker-v3") + ############################################################ + # BipedalWalker-v3 + ############################################################ + n_cpus = 32 + envs = get_vectorized_envs(n_cpus=n_cpus, id="BipedalWalker-v3") model = RecurrentPPO( "MlpLstmPolicy", envs, - n_steps=2048, - batch_size=64, + n_steps=256, + batch_size=256, gae_lambda=0.95, gamma=0.999, whole_sequences=False, @@ -31,25 +72,96 @@ def get_vectorized_envs(n_cpus, **kwargs): ent_coef=0.0, learning_rate=0.0003, clip_range=0.18, + policy_kwargs={ + "ortho_init": False, + "activation_fn": nn.ReLU, + "lstm_hidden_size": 64, + "enable_critic_lstm": True, + "net_arch": [dict(pi=[64], vf=[64])] + }, tensorboard_log="temp/", verbose=1 ) - model.learn(500000, tb_log_name="speed test sb3-contrib default 2d batching") + model.learn(5e6, tb_log_name="BipedalWalker-v3_sb3_standard") model = RecurrentPPO( "MlpLstmPolicy", envs, - n_steps=2048, + n_steps=256, batch_size=4, gae_lambda=0.95, gamma=0.999, - whole_sequences=True, + whole_sequences=True, # This sets use of whole sequence batching n_epochs=10, ent_coef=0.0, learning_rate=0.0003, clip_range=0.18, + policy_kwargs={ + "ortho_init": False, + "activation_fn": nn.ReLU, + "lstm_hidden_size": 64, + "enable_critic_lstm": True, + "net_arch": [dict(pi=[64], vf=[64])] + }, + tensorboard_log="temp/", + verbose=1 + ) + model.learn(5e6, tb_log_name="BipedalWalker-v3_whole_sequences") + + + ############################################################ + # 
PendulumNoVel + ############################################################ + + n_cpus = 4 + envs = get_vectorized_envs(n_cpus=n_cpus, id="Pendulum-v1", mask_vel=True) + + model = RecurrentPPO( + "MlpLstmPolicy", + envs, + n_steps=1024, + # batch_size=256, + gae_lambda=0.95, + gamma=0.9, + whole_sequences=False, + n_epochs=10, + ent_coef=0.0, + learning_rate=0.001, + clip_range=0.2, + policy_kwargs={ + "ortho_init": False, + "activation_fn": nn.ReLU, + "lstm_hidden_size": 64, + "enable_critic_lstm": True, + "net_arch": [dict(pi=[64], vf=[64])] + }, + tensorboard_log="temp/", + verbose=1 + ) + model.learn(2e5, tb_log_name="PendulumNoVel-v1_sb3_standard") + + for batch_size in [2, 4, 8]: + model = RecurrentPPO( + "MlpLstmPolicy", + envs, + n_steps=1024, + batch_size=batch_size, + gae_lambda=0.95, + gamma=0.9, + whole_sequences=True, # This sets use of whole sequence batching + n_epochs=10, + ent_coef=0.0, + learning_rate=0.001, + clip_range=0.2, + policy_kwargs={ + "ortho_init": False, + "activation_fn": nn.ReLU, + "lstm_hidden_size": 64, + "enable_critic_lstm": True, + "net_arch": [dict(pi=[64], vf=[64])] + }, + tensorboard_log="temp/", + verbose=1 + ) + model.learn(2e5, tb_log_name=f"PendulumNoVel-v1_whole_sequences_batch_size{batch_size}") \ No newline at end of file From 9196049543e485001cade9956f0196198689bd6b Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Mon, 28 Nov 2022 17:29:20 +0100 Subject: [PATCH 06/17] fixed NaNs due to super-small batch sizes occurring in an edge case by dropping those batches --- sb3_contrib/common/recurrent/buffers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py index c9513fe1..5a130fc8 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -443,7 +443,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf self.generator_ready = True random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler - batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length)) for indices in batch_sampler: @@ -523,7 +523,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou self.generator_ready = True random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler - batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates From d3af84deba984e7ae5f15c51389658f5bc6c67ca Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Sun, 8 Jan 2023 13:58:53 +0100 Subject: [PATCH 07/17] refactoring, made code simpler.
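As background for the `drop_last=True` change in the previous patch: with `drop_last=False`, `BatchSampler` can emit a final mini-batch containing only one or two episodes, whose statistics (for example the masked advantage standard deviation) are unreliable and can produce NaNs; dropping the remainder batch avoids that. A standalone illustration with toy indices:

```python
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

indices = range(10)  # e.g. 10 complete episodes in the rollout

# drop_last=False can leave a tiny remainder batch at the end.
print([len(b) for b in BatchSampler(SubsetRandomSampler(indices), batch_size=3, drop_last=False)])
# -> [3, 3, 3, 1]

# drop_last=True simply discards that remainder.
print([len(b) for b in BatchSampler(SubsetRandomSampler(indices), batch_size=3, drop_last=True)])
# -> [3, 3, 3]
```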
--- sb3_contrib/common/recurrent/buffers.py | 68 +++++++++++-------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 0ef3c4d8..addb8110 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -431,38 +431,34 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf "log_probs", "advantages", "returns", - "hidden_states_pi", - "cell_states_pi", - "hidden_states_vf", - "cell_states_vf", "episode_starts", ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.episode_start_indices = np.where(self.episode_starts == 1)[0] self.generator_ready = True - + random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length)) for indices in batch_sampler: - obs_batch = pad_sequence([th.tensor(self.observations[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - masks_batch = pad_sequence([th.ones_like(r) for r in returns_batch], batch_first=True) + obs_batch = pad_sequence([th.tensor(self.observations[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + masks_batch = pad_sequence([th.ones_like(r) for r in th.swapaxes(returns_batch, 0, 1)]) yield RecurrentRolloutBufferSequenceSamples( - observations=th.swapaxes(obs_batch, 0, 1), - actions=th.swapaxes(actions_batch, 0, 1), - 
old_values=th.swapaxes(old_values_batch, 0, 1), - old_log_prob=th.swapaxes(old_log_probs_batch, 0, 1), - advantages=th.swapaxes(advantages_batch, 0, 1), - returns=th.swapaxes(returns_batch, 0, 1), - masks=th.swapaxes(masks_batch, 0, 1) + observations=obs_batch, + actions=actions_batch, + old_values=old_values_batch, + old_log_prob=old_log_probs_batch, + advantages=advantages_batch, + returns=returns_batch, + masks=masks_batch ) @@ -511,17 +507,13 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou "log_probs", "advantages", "returns", - "hidden_states_pi", - "cell_states_pi", - "hidden_states_vf", - "cell_states_vf", "episode_starts", ]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.episode_start_indices = np.where(self.episode_starts == 1)[0] self.generator_ready = True - + random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates @@ -529,21 +521,21 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou for indices in batch_sampler: obs_batch = {} for key in self.observations: - obs_batch[key] = pad_sequence([th.tensor(self.observations[key][self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) + obs_batch[key] = pad_sequence([th.tensor(self.observations[key][self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) - actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices], batch_first=True) - masks_batch = pad_sequence([th.ones_like(r) for r in returns_batch], batch_first=True) + actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices]) + masks_batch = pad_sequence([th.ones_like(r) for r in th.swapaxes(returns_batch, 0, 
1)])
            yield RecurrentDictRolloutBufferSequenceSamples(
-                observations={key:th.swapaxes(obs_batch[key], 0, 1) for key in obs_batch},
-                actions=th.swapaxes(actions_batch, 0, 1),
-                old_values=th.swapaxes(old_values_batch, 0, 1),
-                old_log_prob=th.swapaxes(old_log_probs_batch, 0, 1),
-                advantages=th.swapaxes(advantages_batch, 0, 1),
-                returns=th.swapaxes(returns_batch, 0, 1),
-                masks=th.swapaxes(masks_batch, 0, 1)
+                observations=obs_batch,
+                actions=actions_batch,
+                old_values=old_values_batch,
+                old_log_prob=old_log_probs_batch,
+                advantages=advantages_batch,
+                returns=returns_batch,
+                masks=masks_batch
            )
\ No newline at end of file

From 18ace01f01957c8b61c07fe5a195200ac3b7c12b Mon Sep 17 00:00:00 2001
From: b-x-ubuntu
Date: Sun, 8 Jan 2023 14:16:52 +0100
Subject: [PATCH 08/17] improved indexing to sample all sequences

---
 sb3_contrib/common/recurrent/buffers.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py
index addb8110..89127f1c 100644
--- a/sb3_contrib/common/recurrent/buffers.py
+++ b/sb3_contrib/common/recurrent/buffers.py
@@ -438,8 +438,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf
        self.episode_start_indices = np.where(self.episode_starts == 1)[0]
        self.generator_ready = True

-        random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler
+        random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)))
        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates
+        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) # add a dummy index to make the code below simpler

        # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length))
        for indices in batch_sampler:
@@ -514,8 +515,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou
        self.episode_start_indices = np.where(self.episode_starts == 1)[0]
        self.generator_ready = True

-        random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)-1)) # dropping the last one to make indexing the np arrays much simpler
+        random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)))
        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates
+        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) # add a dummy index to make the code below simpler

        # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length)
        for indices in batch_sampler:

From de092ba7756715b3c368fa548afacc03aebee2e7 Mon Sep 17 00:00:00 2001
From: b-x-ubuntu
Date: Wed, 1 Mar 2023 11:11:45 +0100
Subject: [PATCH 09/17] integrated whole sequence train function with existing

---
 sb3_contrib/common/recurrent/buffers.py | 4 +-
 sb3_contrib/common/recurrent/type_aliases.py | 4 +-
 sb3_contrib/ppo_recurrent/ppo_recurrent.py | 159 ++-----------------
 3 files changed, 21 insertions(+), 146 deletions(-)

diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py
index 89127f1c..c5a2bd46 100644
--- a/sb3_contrib/common/recurrent/buffers.py
+++ b/sb3_contrib/common/recurrent/buffers.py
@@ -459,7 +459,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf
                old_log_prob=old_log_probs_batch,
                advantages=advantages_batch,
                returns=returns_batch,
-                masks=masks_batch
+                mask=masks_batch
            )
@@ -539,5 +539,5 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou
                old_log_prob=old_log_probs_batch,
                advantages=advantages_batch,
                returns=returns_batch,
-                masks=masks_batch
+                mask=masks_batch
            )
\ No newline at end of file
diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py
index caed9605..f573be83 100644
--- a/sb3_contrib/common/recurrent/type_aliases.py
+++ b/sb3_contrib/common/recurrent/type_aliases.py
@@ -40,7 +40,7 @@ class RecurrentRolloutBufferSequenceSamples(NamedTuple):
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
-    masks: th.Tensor
+    mask: th.Tensor


class RecurrentDictRolloutBufferSequenceSamples(NamedTuple):
@@ -50,4 +50,4 @@ class RecurrentDictRolloutBufferSequenceSamples(NamedTuple):
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
-    masks: th.Tensor
\ No newline at end of file
+    mask: th.Tensor
\ No newline at end of file
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index d541d37d..be774d7d 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -317,144 +317,10 @@ def collect_rollouts(
        return True

-    def train_whole_sequences(self) -> None:
-        """
-        Update policy using the currently gathered rollout buffer but do it on 3d batches of whole sequences.
-        """
-        # Switch to train mode (this affects batch norm / dropout)
-        self.policy.set_training_mode(True)
-        # Update optimizer learning rate
-        self._update_learning_rate(self.policy.optimizer)
-        # Compute current clip range
-        clip_range = self.clip_range(self._current_progress_remaining)
-        # Optional: clip range for the value function
-        if self.clip_range_vf is not None:
-            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
-
-        entropy_losses = []
-        pg_losses, value_losses = [], []
-        clip_fractions = []
-
-        continue_training = True
-
-        # train for n_epochs epochs
-        for epoch in range(self.n_epochs):
-            approx_kl_divs = []
-            # Do a complete pass on the rollout buffer
-            for rollout_data in self.rollout_buffer.get(self.batch_size):
-                actions = rollout_data.actions
-                if isinstance(self.action_space, spaces.Discrete):
-                    # Convert discrete action from float to long
-                    actions = rollout_data.actions.long()
-
-                # Convert mask from float to bool
-                mask = rollout_data.masks > 1e-8
-
-                # Re-sample the noise matrix because the log_std has changed
-                if self.use_sde:
-                    self.policy.reset_noise(self.batch_size)
-
-                values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
-                    rollout_data.observations,
-                    actions,
-                )
-
-                # Normalize advantage
-                advantages = rollout_data.advantages
-                if self.normalize_advantage:
-                    advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8)
-
-                # ratio between old and new policy, should be one at the first iteration
-                ratio = th.exp(log_prob - rollout_data.old_log_prob)
-
-                # clipped surrogate loss
-                policy_loss_1 = advantages * ratio
-                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-                policy_loss = -th.mean(th.min(policy_loss_1, policy_loss_2)[mask])
-
-                # Logging
-                pg_losses.append(policy_loss.item())
-                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()[mask]).item()
-                clip_fractions.append(clip_fraction)
-
-                if self.clip_range_vf is None:
-                    # No clipping
-                    values_pred = values
-                else:
-                    # Clip the different between old and new value
-                    # NOTE: this depends on the reward scaling
-                    values_pred = rollout_data.old_values + th.clamp(
-                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
-                    )
-                # Value loss using the TD(gae_lambda) target
-                # Mask padded sequences
-                value_loss = th.mean(((rollout_data.returns - values_pred) ** 2)[mask])
-
-                value_losses.append(value_loss.item())
-
-                # Entropy loss favor exploration
-                if entropy is None:
-                    # Approximate entropy when no analytical form
-                    entropy_loss = -th.mean(-log_prob[mask])
-                else:
-                    entropy_loss = -th.mean(entropy[mask])
-
-                entropy_losses.append(entropy_loss.item())
-
-                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
-
-                # Calculate approximate form of reverse KL Divergence for early stopping
-                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
-                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
-                # and Schulman blog: http://joschu.net/blog/kl-approx.html
-                with th.no_grad():
-                    log_ratio = log_prob - rollout_data.old_log_prob
-                    approx_kl_div = th.mean(((th.exp(log_ratio) - 1) - log_ratio)[mask]).cpu().numpy()
-                    approx_kl_divs.append(approx_kl_div)
-
-                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
-                    continue_training = False
-                    if self.verbose >= 1:
-                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
-                    break
-
-                # Optimization step
-                self.policy.optimizer.zero_grad()
-                loss.backward()
-                # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
-                self.policy.optimizer.step()
-
-            if not continue_training:
-                break
-
-        self._n_updates += self.n_epochs
-        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
-
-        # Logs
-        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
-        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
-        self.logger.record("train/value_loss", np.mean(value_losses))
-        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
-        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
-        self.logger.record("train/loss", loss.item())
-        self.logger.record("train/explained_variance", explained_var)
-        if hasattr(self.policy, "log_std"):
-            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
-
-        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
-        self.logger.record("train/clip_range", clip_range)
-        if self.clip_range_vf is not None:
-            self.logger.record("train/clip_range_vf", clip_range_vf)
-
    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
""" - if self.whole_sequences == True: - self.train_whole_sequences() - return - # Switch to train mode (this affects batch norm / dropout) self.policy.set_training_mode(True) # Update optimizer learning rate @@ -479,7 +345,10 @@ def train(self) -> None: actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long - actions = rollout_data.actions.long().flatten() + if self.whole_sequences: + actions = rollout_data.actions.long() + else: + actions = rollout_data.actions.long().flatten() # Convert mask from float to bool mask = rollout_data.mask > 1e-8 @@ -488,14 +357,20 @@ def train(self) -> None: if self.use_sde: self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions( - rollout_data.observations, - actions, - rollout_data.lstm_states, - rollout_data.episode_starts, - ) + if self.whole_sequences: + values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence( + rollout_data.observations, + actions, + ) + else: + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, + actions, + rollout_data.lstm_states, + rollout_data.episode_starts, + ) + values = values.flatten() - values = values.flatten() # Normalize advantage advantages = rollout_data.advantages if self.normalize_advantage: From 5dc8bc9897b2af6b7f5f4a87e1ff36df72bba042 Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Wed, 1 Mar 2023 12:14:19 +0100 Subject: [PATCH 10/17] improvement of isntance checking --- sb3_contrib/common/recurrent/policies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index a3669e5f..33fc56a5 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -362,7 +362,10 @@ def evaluate_actions_whole_sequence( and entropy of the action distribution. 
""" # Preprocess the observation if needed - if self.features_extractor_class == FlattenExtractor: # temporary fix to disable the flattening that stable_baselines3 feature extractors do by default + + # temporary fix to disable the flattening that stable_baselines3 feature extractors do by default + # flattening will turn the sequences in the batch into 1 long sequence without proper resetting of lstm hidden states + if isinstance(self.features_extractor_class, FlattenExtractor): features = obs else: features = self.extract_features(obs) From e543b310d9f1a76ecd24dda60d66b3736a003653 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 3 Apr 2023 15:26:41 +0200 Subject: [PATCH 11/17] Reformat and simplify --- sb3_contrib/common/recurrent/buffers.py | 106 ++++++++++++------- sb3_contrib/common/recurrent/policies.py | 4 +- sb3_contrib/common/recurrent/type_aliases.py | 2 +- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 19 +++- 4 files changed, 85 insertions(+), 46 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 0f645d65..46fe9eed 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -1,20 +1,20 @@ from functools import partial -from typing import Callable, Generator, Optional, Tuple, Union +from typing import Callable, Generator, List, Optional, Tuple, Union import numpy as np import torch as th -from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler -from torch.nn.utils.rnn import pad_sequence from gym import spaces from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.vec_env import VecNormalize +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler from sb3_contrib.common.recurrent.type_aliases import ( RecurrentDictRolloutBufferSamples, + RecurrentDictRolloutBufferSequenceSamples, RecurrentRolloutBufferSamples, - RNNStates, RecurrentRolloutBufferSequenceSamples, - RecurrentDictRolloutBufferSequenceSamples, + RNNStates, ) @@ -98,6 +98,30 @@ def create_sequencers( return seq_start_indices, local_pad, local_pad_and_flatten +def create_sequence_slicer( + episode_start_indices: np.ndarray, device: Union[th.device, str] +) -> Callable[[np.ndarray, List[str]], th.Tensor]: + def create_sequence_minibatch(tensor: np.ndarray, seq_indices: List[str]) -> th.Tensor: + """ + Create minibatch of whole sequence. + + :param tensor: Tensor that will be sliced (e.g. observations, rewards) + :param seq_indices: Sequences to be used. + :return: (max_sequence_length, batch_size=n_seq, features_size) + """ + return pad_sequence( + [ + th.tensor( + tensor[episode_start_indices[i] : episode_start_indices[i + 1]], + device=device, + ) + for i in seq_indices + ] + ) + + return create_sequence_minibatch + + class RecurrentRolloutBuffer(RolloutBuffer): """ Rollout buffer that also stores the LSTM cell and hidden states. 
@@ -417,7 +441,9 @@ def __init__(
    ):
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
-        super().__init__(buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs)
+        super().__init__(
+            buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
+        )

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSequenceSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"
@@ -439,27 +465,26 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf
        self.generator_ready = True

        random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)))
-        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates
-        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) # add a dummy index to make the code below simpler
+        # drop last batch to prevent extremely small batches causing spurious updates
+        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True)
+        # add a dummy index to make the code below simpler
+        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])])
+
+        create_minibatch = create_sequence_slicer(self.episode_start_indices, self.device)

-        # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length))
+        # yields batches of whole sequences, shape: (max_sequence_length, batch_size=n_seq, features_size))
        for indices in batch_sampler:
-            obs_batch = pad_sequence([th.tensor(self.observations[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            masks_batch = pad_sequence([th.ones_like(r) for r in th.swapaxes(returns_batch, 0, 1)])
+            returns_batch = create_minibatch(self.returns, indices)
+            masks_batch = pad_sequence([th.ones_like(returns) for returns in th.swapaxes(returns_batch, 0, 1)])

            yield RecurrentRolloutBufferSequenceSamples(
-                observations=obs_batch,
-                actions=actions_batch,
-                old_values=old_values_batch,
-                old_log_prob=old_log_probs_batch,
-                advantages=advantages_batch,
+                observations=create_minibatch(self.observations, indices),
+                actions=create_minibatch(self.actions, indices),
+                old_values=create_minibatch(self.values, indices),
+                old_log_prob=create_minibatch(self.log_probs, indices),
+                advantages=create_minibatch(self.advantages, indices),
                returns=returns_batch,
-                mask=masks_batch
+                mask=masks_batch,
            )
@@ -492,7 +517,9 @@ def __init__(
    ):
        self.hidden_state_shape = hidden_state_shape
        self.seq_start_indices, self.seq_end_indices = None, None
-        super().__init__(buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs)
+        super().__init__(
+            buffer_size, observation_space, action_space, hidden_state_shape, device, gae_lambda, gamma, n_envs=n_envs
+        )

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRolloutBufferSequenceSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"
@@ -516,28 +543,27 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou
        self.generator_ready = True

        random_indices = SubsetRandomSampler(range(len(self.episode_start_indices)))
-        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # drop last batch to prevent extremely small batches causing spurious updates
-        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) # add a dummy index to make the code below simpler
+        # drop last batch to prevent extremely small batches causing spurious updates
+        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True)
+        # add a dummy index to make the code below simpler
+        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])])
+
+        create_minibatch = create_sequence_slicer(self.episode_start_indices, self.device)

-        # yields batches of whole sequences, shape: (sequence_length, batch_size, data_length)
+        # yields batches of whole sequences, shape: (sequence_length, batch_size=n_seq, features_size)
        for indices in batch_sampler:
            obs_batch = {}
            for key in self.observations:
-                obs_batch[key] = pad_sequence([th.tensor(self.observations[key][self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-
-            actions_batch = pad_sequence([th.tensor(self.actions[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            old_values_batch = pad_sequence([th.tensor(self.values[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            old_log_probs_batch = pad_sequence([th.tensor(self.log_probs[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            advantages_batch = pad_sequence([th.tensor(self.advantages[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            returns_batch = pad_sequence([th.tensor(self.returns[self.episode_start_indices[i]:self.episode_start_indices[i+1]], device=self.device) for i in indices])
-            masks_batch = pad_sequence([th.ones_like(r) for r in th.swapaxes(returns_batch, 0, 1)])
+                obs_batch[key] = create_minibatch(self.observations[key], indices)
+            returns_batch = create_minibatch(self.returns, indices)
+            masks_batch = pad_sequence([th.ones_like(returns) for returns in th.swapaxes(returns_batch, 0, 1)])

            yield RecurrentDictRolloutBufferSequenceSamples(
                observations=obs_batch,
-                actions=actions_batch,
-                old_values=old_values_batch,
-                old_log_prob=old_log_probs_batch,
-                advantages=advantages_batch,
+                actions=create_minibatch(self.actions, indices),
+                old_values=create_minibatch(self.values, indices),
+                old_log_prob=create_minibatch(self.log_probs, indices),
+                advantages=create_minibatch(self.advantages, indices),
                returns=returns_batch,
-                mask=masks_batch
-            )
\ No newline at end of file
+                mask=masks_batch,
+            )
diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py
index 33fc56a5..9a89a848 100644
--- a/sb3_contrib/common/recurrent/policies.py
+++ b/sb3_contrib/common/recurrent/policies.py
@@ -385,10 +385,10 @@ def evaluate_actions_whole_sequence(
        distribution = self._get_action_dist_from_latent(latent_pi)

        log_prob = distribution.distribution.log_prob(actions).sum(dim=-1)
-        log_prob = log_prob.reshape(log_prob.shape + (1,))
+        log_prob = log_prob.reshape((*log_prob.shape, 1))

        entropy = distribution.distribution.entropy().sum(dim=-1)
-        entropy = entropy.reshape(entropy.shape + (1,))
+        entropy = entropy.reshape((*entropy.shape, 1))

        return values, log_prob, entropy
diff --git a/sb3_contrib/common/recurrent/type_aliases.py b/sb3_contrib/common/recurrent/type_aliases.py
index f573be83..33fffd59 100644
--- a/sb3_contrib/common/recurrent/type_aliases.py
+++ b/sb3_contrib/common/recurrent/type_aliases.py
@@ -50,4 +50,4 @@ class RecurrentDictRolloutBufferSequenceSamples(NamedTuple):
    old_log_prob: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor
-    mask: th.Tensor
\ No newline at end of file
+    mask: th.Tensor
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index 75d8f0f6..d6dc7767 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -14,7 +14,12 @@
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv

-from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer, RecurrentSequenceDictRolloutBuffer, RecurrentSequenceRolloutBuffer
+from sb3_contrib.common.recurrent.buffers import (
+    RecurrentDictRolloutBuffer,
+    RecurrentRolloutBuffer,
+    RecurrentSequenceDictRolloutBuffer,
+    RecurrentSequenceRolloutBuffer,
+)
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib.common.recurrent.type_aliases import RNNStates
from sb3_contrib.ppo_recurrent.policies import CnnLstmPolicy, MlpLstmPolicy, MultiInputLstmPolicy
@@ -143,7 +148,9 @@ def _setup_model(self) -> None:
        # 3d batches of whole sequences or 2d batches of constant size
        if self.whole_sequences:
            buffer_cls = (
-                RecurrentSequenceDictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RecurrentSequenceRolloutBuffer
+                RecurrentSequenceDictRolloutBuffer
+                if isinstance(self.observation_space, spaces.Dict)
+                else RecurrentSequenceRolloutBuffer
            )
        else:
            buffer_cls = (
@@ -221,7 +228,13 @@ def collect_rollouts(
            collected, False if callback terminated rollout prematurely.
""" assert isinstance( - rollout_buffer, (RecurrentRolloutBuffer, RecurrentDictRolloutBuffer, RecurrentSequenceRolloutBuffer, RecurrentSequenceDictRolloutBuffer) + rollout_buffer, + ( + RecurrentRolloutBuffer, + RecurrentDictRolloutBuffer, + RecurrentSequenceRolloutBuffer, + RecurrentSequenceDictRolloutBuffer, + ), ), f"{rollout_buffer} doesn't support recurrent policy" assert self._last_obs is not None, "No previous observation was provided" From 194f75803905c831b1877c455d0301c587a73461 Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Wed, 12 Apr 2023 11:00:39 +0200 Subject: [PATCH 12/17] bug fix flatten extractor instance check --- sb3_contrib/common/recurrent/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py index 9a89a848..21472eb1 100644 --- a/sb3_contrib/common/recurrent/policies.py +++ b/sb3_contrib/common/recurrent/policies.py @@ -365,7 +365,7 @@ def evaluate_actions_whole_sequence( # temporary fix to disable the flattening that stable_baselines3 feature extractors do by default # flattening will turn the sequences in the batch into 1 long sequence without proper resetting of lstm hidden states - if isinstance(self.features_extractor_class, FlattenExtractor): + if self.features_extractor_class == FlattenExtractor: features = obs else: features = self.extract_features(obs) From 8f18c9c14fc1cb1e326c171330b08ae16c275433 Mon Sep 17 00:00:00 2001 From: b-x-ubuntu Date: Wed, 12 Apr 2023 11:18:11 +0200 Subject: [PATCH 13/17] simplified if statement --- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index d6dc7767..8c4ae2eb 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -358,10 +358,9 @@ def train(self) -> None: actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long - if self.whole_sequences: - actions = rollout_data.actions.long() - else: - actions = rollout_data.actions.long().flatten() + actions = rollout_data.actions.long() + if not self.whole_sequences: + actions = actions.flatten() # Convert mask from float to bool mask = rollout_data.mask > 1e-8 From ad0d3edec8ba4b43aee5697d9439a815db210449 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 27 Apr 2023 11:52:13 +0200 Subject: [PATCH 14/17] Update comment --- sb3_contrib/common/recurrent/buffers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 2e1a8676..3428d499 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -466,7 +466,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf random_indices = SubsetRandomSampler(range(len(self.episode_start_indices))) # drop last batch to prevent extremely small batches causing spurious updates - batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) + # TODO: allow to change that parameter, otherwise nothing can be sampled + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) # add a dummy index to make the code below simpler self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) From 
b43e9b55b96eb550b2af1991fa367f9bbf967c77 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 27 Apr 2023 11:57:02 +0200 Subject: [PATCH 15/17] Re-add drop last, was causing NaN --- sb3_contrib/common/recurrent/buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 3428d499..edfbbcaf 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -467,7 +467,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf random_indices = SubsetRandomSampler(range(len(self.episode_start_indices))) # drop last batch to prevent extremely small batches causing spurious updates # TODO: allow to change that parameter, otherwise nothing can be sampled - batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) # add a dummy index to make the code below simpler self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) From ef37cc79ada4c2b0ee957b1eb1a98a3f608d1b26 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Thu, 27 Apr 2023 12:00:59 +0200 Subject: [PATCH 16/17] Fix NaN --- sb3_contrib/common/recurrent/buffers.py | 6 +++--- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index edfbbcaf..b2ee584c 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -465,9 +465,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf self.generator_ready = True random_indices = SubsetRandomSampler(range(len(self.episode_start_indices))) - # drop last batch to prevent extremely small batches causing spurious updates - # TODO: allow to change that parameter, otherwise nothing can be sampled - batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True) + # Do not drop last batch so we are sure we sample at least one sequence + # TODO: allow to change that parameter + batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False) # add a dummy index to make the code below simpler self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])]) diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index e3912b80..9d0e980d 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -389,7 +389,8 @@ def train(self) -> None: # Normalize advantage advantages = rollout_data.advantages - if self.normalize_advantage: + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if self.normalize_advantage and len(advantages) > 1: advantages = (advantages - advantages[mask].mean()) / (advantages[mask].std() + 1e-8) # ratio between old and new policy, should be one at the first iteration From a2382c8341c65da779fd488243d0f1a40b0ae623 Mon Sep 17 00:00:00 2001 From: b-vm Date: Mon, 29 May 2023 18:46:25 +0200 Subject: [PATCH 17/17] fixed bug: append correct index, and only once. 
---
 sb3_contrib/common/recurrent/buffers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py
index b2ee584c..ccef3553 100644
--- a/sb3_contrib/common/recurrent/buffers.py
+++ b/sb3_contrib/common/recurrent/buffers.py
@@ -469,9 +469,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf
        # TODO: allow to change that parameter
        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=False)
        # add a dummy index to make the code below simpler
-        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])])
+        episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_starts)])])

-        create_minibatch = create_sequence_slicer(self.episode_start_indices, self.device)
+        create_minibatch = create_sequence_slicer(episode_start_indices, self.device)

        # yields batches of whole sequences, shape: (max_sequence_length, batch_size=n_seq, features_size))
        for indices in batch_sampler:
@@ -547,9 +547,9 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentDictRollou
        # drop last batch to prevent extremely small batches causing spurious updates
        batch_sampler = BatchSampler(random_indices, batch_size, drop_last=True)
        # add a dummy index to make the code below simpler
-        self.episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_start_indices)])])
+        episode_start_indices = np.concatenate([self.episode_start_indices, np.array([len(self.episode_starts)])])

-        create_minibatch = create_sequence_slicer(self.episode_start_indices, self.device)
+        create_minibatch = create_sequence_slicer(episode_start_indices, self.device)

        # yields batches of whole sequences, shape: (sequence_length, batch_size=n_seq, features_size)
        for indices in batch_sampler:
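
Taken together, the series above lets RecurrentPPO train on 3D minibatches of whole padded sequences instead of fixed-size flattened chunks. A rough usage sketch follows; it assumes the `whole_sequences` switch referenced in `train()` and `_setup_model()` is exposed as a constructor keyword argument (its exact signature is not shown in these hunks) and that `batch_size` then counts whole sequences per minibatch rather than individual transitions.

# Hedged sketch, not part of the patch series; argument names beyond those
# visible in the diffs above are assumptions.
from sb3_contrib import RecurrentPPO

model = RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    whole_sequences=True,  # assumed kwarg: select the RecurrentSequence* rollout buffers
    batch_size=8,          # with whole-sequence batching, number of sequences per minibatch
    verbose=1,
)
model.learn(total_timesteps=10_000)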