
RecurrentActorCriticPolicy Behaviour Not Clear #246

Open
2 tasks done
pasinit opened this issue May 9, 2024 · 1 comment
Labels
documentation Improvements or additions to documentation


pasinit commented May 9, 2024

📚 Documentation

I am trying to understand how `RecurrentActorCriticPolicy` works. Coming from an NLP background, I am used to having tensors of the shape (batch_size, seq_len, feature_dim) as input to the LSTM (and optional starting hidden states). From what I can see, however, the LSTM implementation only allows feeding sequences of length 1:

```python
for features, episode_start in zip_strict(features_sequence, episode_starts):
```

Since `zip_strict` requires both iterables to have the same length, zipping `features_sequence` (with shape `[seq_len, n_envs, feature_dims]`) with `episode_starts` (with shape `[n_envs, -1]`) means that, with a single environment, `seq_len` can only be 1.

Is this intended, and am I reading it correctly? Is the reasoning that, since the LSTM state keeps being propagated from one step to the next, sequences of length 1 are sufficient?
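A minimal sketch of the equivalence this question hinges on, using plain `torch.nn.LSTM` rather than the sb3-contrib code (sizes and seed are arbitrary): feeding one length-1 step per call while carrying `(h, c)` forward produces the same outputs as a single call over the whole sequence, as long as the state is never reset mid-sequence.

```python
import torch as th
from torch import nn

th.manual_seed(0)
# Arbitrary sizes chosen for illustration
seq_len, n_envs, feature_dim, hidden_dim = 6, 2, 4, 8
lstm = nn.LSTM(feature_dim, hidden_dim)
x = th.randn(seq_len, n_envs, feature_dim)  # (seq_len, n_envs, feature_dim)
initial_state = (th.zeros(1, n_envs, hidden_dim), th.zeros(1, n_envs, hidden_dim))

with th.no_grad():
    # One call over the whole sequence
    full_out, full_state = lstm(x, initial_state)

    # Manual unroll: one length-1 sequence per call, carrying (h, c) forward
    step_state = initial_state
    step_outs = []
    for x_t in x:  # x_t has shape (n_envs, feature_dim)
        out_t, step_state = lstm(x_t.unsqueeze(0), step_state)
        step_outs.append(out_t)
    step_out = th.cat(step_outs)  # back to (seq_len, n_envs, hidden_dim)

print(th.allclose(full_out, step_out, atol=1e-5))  # True
print(th.allclose(full_state[0], step_state[0], atol=1e-5))  # True
```

The step-by-step loop is slower, which is why a single full-sequence call is preferable whenever no reset is needed.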


pasinit added the documentation label on May 9, 2024
araffin (Member) commented May 10, 2024

> tensors of the shape (batch_size, seq_len, feature_dim) as input to the LSTM

That's correct.
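For reference, a minimal sketch (assuming plain PyTorch; the sizes are arbitrary) of the two input layouts `torch.nn.LSTM` accepts. The default `batch_first=False` expects `(seq_len, batch, feature_dim)`, the layout of `features_sequence` in the code quoted in the question, while `batch_first=True` accepts the `(batch_size, seq_len, feature_dim)` layout familiar from NLP. Hidden and cell states are `(num_layers, batch, hidden_dim)` in both cases.

```python
import torch as th
from torch import nn

# Arbitrary illustrative sizes (not taken from sb3-contrib)
seq_len, batch_size, feature_dim, hidden_dim = 5, 3, 4, 8

# Default layout (batch_first=False): input is (seq_len, batch, feature_dim)
lstm = nn.LSTM(feature_dim, hidden_dim)
x = th.randn(seq_len, batch_size, feature_dim)
out, (h_n, c_n) = lstm(x)
print(tuple(out.shape))  # (5, 3, 8)
print(tuple(h_n.shape))  # (1, 3, 8) -> (num_layers, batch, hidden_dim)

# NLP-style layout (batch_first=True): input is (batch, seq_len, feature_dim)
lstm_bf = nn.LSTM(feature_dim, hidden_dim, batch_first=True)
out_bf, (h_bf, c_bf) = lstm_bf(x.transpose(0, 1))
print(tuple(out_bf.shape))  # (3, 5, 8)
print(tuple(h_bf.shape))  # (1, 3, 8); states are never batch-first
```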

I think you missed:

```python
# If we don't have to reset the state in the middle of a sequence
# we can avoid the for loop, which speeds up things
if th.all(episode_starts == 0.0):
    lstm_output, lstm_states = lstm(features_sequence, lstm_states)
    lstm_output = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1)
    return lstm_output, lstm_states
```

Here, the full sequence is passed to the LSTM in a single call.
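To illustrate the layout change in the last lines of that branch, a small sketch with hand-made values (not real LSTM output): `transpose(0, 1)` followed by `th.flatten(..., start_dim=0, end_dim=1)` turns a `(seq_len, n_seq, hidden_dim)` tensor into a flat batch in which each sequence's time steps are contiguous. The inverse `reshape` + `swapaxes` step is an assumption about how a flat batch is turned back into sequences, shown only to make the round trip explicit.

```python
import torch as th

seq_len, n_seq, hidden_dim = 3, 2, 1
# Hand-made stand-in for LSTM output, shape (seq_len, n_seq, hidden_dim);
# entry [t, s, 0] holds 10 * s + t so the ordering is easy to read
lstm_output = th.tensor(
    [[[10.0 * s + t] for s in range(n_seq)] for t in range(seq_len)]
)

# Sequence -> batch: (seq_len, n_seq, hidden) -> (n_seq, seq_len, hidden) -> (n_seq * seq_len, hidden)
flat = th.flatten(lstm_output.transpose(0, 1), start_dim=0, end_dim=1)
print(flat.squeeze(-1).tolist())  # [0.0, 1.0, 2.0, 10.0, 11.0, 12.0]

# Assumed inverse (batch -> sequence): (n_seq * seq_len, hidden) -> (n_seq, seq_len, hidden) -> (seq_len, n_seq, hidden)
back = flat.reshape((n_seq, -1, hidden_dim)).swapaxes(0, 1)
print(th.equal(back, lstm_output))  # True
```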

Otherwise, we unroll the sequence manually, because the LSTM state must be reset whenever a new episode starts:

```python
# Reset the states at the beginning of a new episode
(1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0],
(1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1],
```
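A minimal sketch of this reset logic, assuming `episode_starts` has already been arranged as `(seq_len, n_seq)` to match `features_sequence`. The helper `unroll_with_resets` and all sizes are hypothetical, and built-in `zip` stands in for `zip_strict`. Multiplying `h` and `c` by `(1.0 - episode_start)` zeroes the state only for sequences whose episode starts at that step, so a sequence that resets at `t = 2` behaves as if it had been started from a zero state there, while the other sequence is unaffected.

```python
import torch as th
from torch import nn

th.manual_seed(0)
# Hypothetical sizes chosen for illustration
seq_len, n_seq, feature_dim, hidden_dim = 4, 2, 3, 5
lstm = nn.LSTM(feature_dim, hidden_dim)
features_sequence = th.randn(seq_len, n_seq, feature_dim)

# episode_starts[t, i] == 1.0 marks a new episode at step t for sequence i;
# here sequence 1 starts a new episode at t = 2
episode_starts = th.zeros(seq_len, n_seq)
episode_starts[2, 1] = 1.0

# Non-zero initial state, as if carried over from an earlier rollout
h0 = th.randn(1, n_seq, hidden_dim)
c0 = th.randn(1, n_seq, hidden_dim)


def unroll_with_resets(features_sequence, episode_starts, lstm_states):
    """Hypothetical helper: unroll one step at a time, resetting on episode starts."""
    outputs = []
    for features, episode_start in zip(features_sequence, episode_starts):
        # Mask is 0 for sequences starting a new episode, 1 otherwise
        mask = (1.0 - episode_start).view(1, n_seq, 1)
        hidden, lstm_states = lstm(
            features.unsqueeze(dim=0),
            (mask * lstm_states[0], mask * lstm_states[1]),
        )
        outputs.append(hidden)
    return th.cat(outputs), lstm_states


with th.no_grad():
    out, _ = unroll_with_resets(features_sequence, episode_starts, (h0, c0))
    # Reference 1: sequence 1 from t = 2 onward, started from a zero state
    zeros = th.zeros(1, 1, hidden_dim)
    ref_reset, _ = lstm(features_sequence[2:, 1:2].contiguous(), (zeros, zeros))
    # Reference 2: sequence 0 never resets, so one uninterrupted call matches
    ref_keep, _ = lstm(
        features_sequence[:, 0:1].contiguous(),
        (h0[:, 0:1].contiguous(), c0[:, 0:1].contiguous()),
    )

print(th.allclose(out[2:, 1:2], ref_reset, atol=1e-5))  # True
print(th.allclose(out[:, 0:1], ref_keep, atol=1e-5))  # True
```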
