Fix reshape LSTM states (#112)
* Fix LSTM states reshape

* Fix warnings and update changelog

* Remove unused variable

* Fix runtime error when using n_lstm_layers > 1
araffin authored Oct 26, 2022
1 parent c75ad7d commit a9735b9
Showing 5 changed files with 30 additions and 22 deletions.
docs/misc/changelog.rst (3 additions, 1 deletion)
@@ -3,7 +3,7 @@
Changelog
==========

-Release 1.7.0a0 (WIP)
+Release 1.7.0a1 (WIP)
--------------------------

Breaking Changes:
@@ -17,6 +17,8 @@ New Features:

Bug Fixes:
^^^^^^^^^^
+- Fixed a bug in ``RecurrentPPO`` where the LSTM states were incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn)
+- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1``

Deprecations:
^^^^^^^^^^^^^
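The first bug-fix entry above concerns the conversion of the stored states from ``(n_seq, n_layers, dim)`` to the ``(n_layers, n_seq, dim)`` layout that the LSTM expects. A minimal sketch with toy shapes (not the library's buffer code) showing why ``reshape`` is not a transpose and only happens to give the right result when ``n_lstm_layers == 1``:

```python
import torch

n_seq, n_layers, dim = 2, 2, 3
# One (n_layers, dim) block of hidden states per sequence start
states = torch.arange(n_seq * n_layers * dim, dtype=torch.float32).reshape(n_seq, n_layers, dim)

# Buggy: reshape only reinterprets memory, so layer and sequence data get mixed
wrong = states.reshape(n_layers, n_seq, dim)
# Fixed: swapaxes actually transposes the first two axes
right = states.swapaxes(0, 1)

print(torch.equal(wrong, right))  # False for n_layers > 1, True for n_layers == 1
print(right[0])  # layer-0 states of both sequences, as expected
```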
sb3_contrib/common/recurrent/buffers.py (18 additions, 18 deletions)
@@ -206,25 +206,26 @@ def _get_samples(
self.episode_starts[batch_inds], env_change[batch_inds], self.device
)

-n_layers = self.hidden_states_pi.shape[1]
# Number of sequences
n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = (
-# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
-self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
-self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
+# 1. (n_envs * n_steps, n_layers, dim) -> (batch_size, n_layers, dim)
+# 2. (batch_size, n_layers, dim) -> (n_seq, n_layers, dim)
+# 3. (n_seq, n_layers, dim) -> (n_layers, n_seq, dim)
+self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
+self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_vf = (
-# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
-self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
-self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
+# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
+self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
+self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
-lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
-lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
+lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
+lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

return RecurrentRolloutBufferSamples(
# (batch_size, obs_dim) -> (n_seq, max_length, obs_dim) -> (n_seq * max_length, obs_dim)
@@ -349,24 +350,23 @@ def _get_samples(
self.episode_starts[batch_inds], env_change[batch_inds], self.device
)

-n_layers = self.hidden_states_pi.shape[1]
n_seq = len(self.seq_start_indices)
max_length = self.pad(self.actions[batch_inds]).shape[1]
padded_batch_size = n_seq * max_length
# We retrieve the lstm hidden states that will allow
# to properly initialize the LSTM at the beginning of each sequence
lstm_states_pi = (
-# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
-self.hidden_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
-self.cell_states_pi[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
+# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
+self.hidden_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
+self.cell_states_pi[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
lstm_states_vf = (
-# (n_steps, n_layers, n_envs, dim) -> (n_layers, n_seq, dim)
-self.hidden_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
-self.cell_states_vf[batch_inds][self.seq_start_indices].reshape(n_layers, n_seq, -1),
+# (n_envs * n_steps, n_layers, dim) -> (n_layers, n_seq, dim)
+self.hidden_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
+self.cell_states_vf[batch_inds][self.seq_start_indices].swapaxes(0, 1),
)
-lstm_states_pi = (self.to_torch(lstm_states_pi[0]), self.to_torch(lstm_states_pi[1]))
-lstm_states_vf = (self.to_torch(lstm_states_vf[0]), self.to_torch(lstm_states_vf[1]))
+lstm_states_pi = (self.to_torch(lstm_states_pi[0]).contiguous(), self.to_torch(lstm_states_pi[1]).contiguous())
+lstm_states_vf = (self.to_torch(lstm_states_vf[0]).contiguous(), self.to_torch(lstm_states_vf[1]).contiguous())

observations = {key: self.pad(obs[batch_inds]) for (key, obs) in self.observations.items()}
observations = {key: obs.reshape((padded_batch_size,) + self.obs_shape[key]) for (key, obs) in observations.items()}
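The step-by-step comments in the patched buffer end with the ``(n_layers, n_seq, dim)`` layout, which matches the ``(num_layers, batch, hidden_size)`` shape that ``torch.nn.LSTM`` expects for its initial hidden and cell states. A small sketch with invented toy dimensions of how such states are consumed:

```python
import torch
import torch.nn as nn

n_layers, n_seq, hidden_size, obs_dim, max_length = 2, 4, 8, 5, 7
lstm = nn.LSTM(obs_dim, hidden_size, num_layers=n_layers)

# States as produced by the buffer after indexing and swapaxes(0, 1)
h0 = torch.zeros(n_layers, n_seq, hidden_size)
c0 = torch.zeros(n_layers, n_seq, hidden_size)

# Padded observation sequences, time-major: (max_length, n_seq, obs_dim)
padded_obs = torch.zeros(max_length, n_seq, obs_dim)

# nn.LSTM requires (h0, c0) shaped (num_layers, batch, hidden_size),
# hence the (n_layers, n_seq, dim) target layout
output, (hn, cn) = lstm(padded_obs, (h0, c0))
print(output.shape)  # torch.Size([7, 4, 8])
```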
sb3_contrib/ppo_recurrent/ppo_recurrent.py (2 additions, 2 deletions)
@@ -277,8 +277,8 @@ def collect_rollouts(
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad():
terminal_lstm_state = (
-lstm_states.vf[0][:, idx : idx + 1, :],
-lstm_states.vf[1][:, idx : idx + 1, :],
+lstm_states.vf[0][:, idx : idx + 1, :].contiguous(),
+lstm_states.vf[1][:, idx : idx + 1, :].contiguous(),
)
# terminal_lstm_state = None
episode_starts = th.tensor([False]).float().to(self.device)
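Indexing a single environment out of the stacked value-function states returns a view that is no longer contiguous in memory once ``n_lstm_layers > 1``, and PyTorch's RNN kernels reject such hidden states. A toy reproduction of the problem and of the ``.contiguous()`` fix (shapes chosen only for illustration):

```python
import torch
import torch.nn as nn

n_layers, n_envs, hidden_size = 2, 4, 8
lstm = nn.LSTM(hidden_size, hidden_size, num_layers=n_layers)

h = torch.randn(n_layers, n_envs, hidden_size)
c = torch.randn(n_layers, n_envs, hidden_size)
terminal_obs = torch.randn(1, 1, hidden_size)  # one step for one terminated env

idx = 1
h_slice = h[:, idx : idx + 1, :]  # view that still strides over all envs
print(h_slice.is_contiguous())  # False when n_layers > 1

try:
    lstm(terminal_obs, (h_slice, c[:, idx : idx + 1, :]))
except RuntimeError as exc:
    print(exc)  # e.g. "rnn: hx is not contiguous"

# The fix applied in collect_rollouts: make the sliced states contiguous
lstm(terminal_obs, (h_slice.contiguous(), c[:, idx : idx + 1, :].contiguous()))
```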
sb3_contrib/version.txt (1 addition, 1 deletion)
@@ -1 +1 @@
-1.7.0a0
+1.7.0a1
tests/test_lstm.py (6 additions, 0 deletions)
@@ -66,11 +66,13 @@ def step(self, action):
enable_critic_lstm=True,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
+n_lstm_layers=2,
),
dict(
enable_critic_lstm=False,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
+n_lstm_layers=2,
),
],
)
@@ -95,11 +97,13 @@ def test_cnn(policy_kwargs):
enable_critic_lstm=True,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
+n_lstm_layers=2,
),
dict(
enable_critic_lstm=False,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
+n_lstm_layers=2,
),
],
)
@@ -162,11 +166,13 @@ def test_run_sde():
enable_critic_lstm=True,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
+n_lstm_layers=2,
),
dict(
enable_critic_lstm=False,
lstm_hidden_size=4,
lstm_kwargs=dict(dropout=0.5),
+n_lstm_layers=2,
),
],
)
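The new test parametrizations exercise two stacked LSTM layers, the configuration that previously triggered both bugs. A usage sketch (environment id and hyperparameters chosen only for illustration):

```python
from sb3_contrib import RecurrentPPO

model = RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    n_steps=64,
    policy_kwargs=dict(
        lstm_hidden_size=4,
        n_lstm_layers=2,  # more than one layer used to hit the reshape/contiguity bugs
        enable_critic_lstm=True,
    ),
    verbose=0,
)
model.learn(total_timesteps=128)
```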
