From a9735b9f317be4283e56d221e19087b926ca9ec0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 26 Oct 2022 18:03:45 +0200 Subject: [PATCH] Fix reshape LSTM states (#112) * Fix LSTM states reshape * Fix warnings and update changelog * Remove unused variable * Fix runtime error when using n_lstm_layers > 1 --- docs/misc/changelog.rst | 4 ++- sb3_contrib/common/recurrent/buffers.py | 36 +++++++++++----------- sb3_contrib/ppo_recurrent/ppo_recurrent.py | 4 +-- sb3_contrib/version.txt | 2 +- tests/test_lstm.py | 6 ++++ 5 files changed, 30 insertions(+), 22 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5cd2d732..f6bddc78 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.7.0a0 (WIP) +Release 1.7.0a1 (WIP) -------------------------- Breaking Changes: @@ -17,6 +17,8 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fixed a bug in ``RecurrentPPO`` where the lstm states were incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn) +- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1`` Deprecations: ^^^^^^^^^^^^^ diff --git a/sb3_contrib/common/recurrent/buffers.py b/sb3_contrib/common/recurrent/buffers.py index 53856681..23d487f1 100644 --- a/sb3_contrib/common/recurrent/buffers.py +++ b/sb3_contrib/common/recurrent/buffers.py @@ -206,7 +206,6 @@ def _get_samples( self.episode_starts[batch_inds], env_change[batch_inds], self.device ) - n_layers = self.hidden_states_pi.shape[1] # Number of sequences n_seq = len(self.seq_start_indices) max_length = self.pad(self.actions[batch_inds]).shape[1] @@ -214,17 +213,19 @@ def _get_samples( # We retrieve the lstm hidden states that will allow # to properly initialize the LSTM at the beginning of each sequence lstm_states_pi = ( - # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - 
self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), - self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + # 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim) + # 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim) + # 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) lstm_states_vf = ( - # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), - self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) - lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) - lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) return RecurrentRolloutBufferSamples( # (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim) @@ -349,24 +350,23 @@ def _get_samples( self.episode_starts[batch_inds], env_change[batch_inds], self.device ) - n_layers = self.hidden_states_pi.shape[1] n_seq = len(self.seq_start_indices) max_length = self.pad(self.actions[batch_inds]).shape[1] padded_batch_size = n_seq * max_length # We retrieve the lstm hidden states that will allow # to properly initialize the LSTM at the beginning of each sequence lstm_states_pi = ( - # (n_steps, n_layers, n_envs, dim) -> 
(n_layers, n_seq, dim) - self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), - self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) lstm_states_vf = ( - # (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim) - self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), - self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1), + # (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim) + self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), + self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1), ) - lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1])) - lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1])) + lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous()) + lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous()) observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()} observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()} diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py index da355163..7ee0f556 100644 --- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py +++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py @@ -277,8 +277,8 @@ def collect_rollouts( terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): terminal_lstm_state = ( - lstm_states.vf[0][:, idx : idx + 1, :], - lstm_states.vf[1][:, idx : idx + 1, :], + 
lstm_states.vf[0][:, idx : idx + 1, :].contiguous(), + lstm_states.vf[1][:, idx : idx + 1, :].contiguous(), ) # terminal_lstm_state = None episode_starts = th.tensor([False]).float().to(self.device) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 56fee069..12cd5fb3 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.7.0a0 +1.7.0a1 diff --git a/tests/test_lstm.py b/tests/test_lstm.py index 29f5ef24..1da67796 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -66,11 +66,13 @@ def step(self, action): enable_critic_lstm=True, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, ), dict( enable_critic_lstm=False, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, ), ], ) @@ -95,11 +97,13 @@ def test_cnn(policy_kwargs): enable_critic_lstm=True, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, ), dict( enable_critic_lstm=False, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, ), ], ) @@ -162,11 +166,13 @@ def test_run_sde(): enable_critic_lstm=True, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, ), dict( enable_critic_lstm=False, lstm_hidden_size=4, lstm_kwargs=dict(dropout=0.5), + n_lstm_layers=2, ), ], )