adding LSTM support to pretrain #315
base: master
Conversation
add frame index alignment to ExpertDataset
update fork
# Conflicts:
#	stable_baselines/gail/dataset/dataset.py
#	stable_baselines/trpo_mpi/trpo_mpi.py
#	tests/test_gail.py
The problem is that
convert float envs_per_batch into int envs_per_batch
Hello, do you consider this PR ready for review? (After a quick look, I saw that a saved file was still there (nano) and there seems to be some code duplication that can be improved ;))
o/
I removed nano.
I am not quite sure, but I think you are referring to the code duplication. The only thing I still plan to do is to add a bit more functionality to it. The code that is already there is finalized so far and can be reviewed.
The tests have failed after updating my branch.
You can ignore this, I've attempted a fix in #467 |
Now that my PR has finally passed all the unit tests again, could you start reviewing it, so that I can then change it if necessary?
Please remove the test_recorded_images folder too.
stable_baselines/a2c/a2c.py
Outdated
@@ -87,9 +87,20 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.

    def _get_pretrain_placeholders(self):
        policy = self.train_model

        if self.initial_state is None:
You can do:
states_ph, snew_ph, dones_ph = None, None, None
so it's more compact, same for the else case
I think having the variable declarations both vertical and horizontal interrupts the reading flow. Yes, it would make the code shorter, but also less readable in my opinion. But I will change it if you really want it that way.
stable_baselines/acer/acer_simple.py
Outdated
@@ -152,8 +152,18 @@ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5,

    def _get_pretrain_placeholders(self):
        policy = self.step_model
        action_ph = policy.pdtype.sample_placeholder([None])

        if self.initial_state is None:
same remark as before
stable_baselines/a2c/a2c.py
Outdated
@@ -87,9 +87,20 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.

    def _get_pretrain_placeholders(self):
        policy = self.train_model

        if self.initial_state is None:
You should rather check the recurrent attribute of the policy; it is in the base policy class.
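What the reviewer suggests can be sketched roughly as follows. This is a simplified stand-in, not the library's actual code: the `recurrent` flag mirrors the attribute on stable-baselines' base policy class, and the placeholder strings stand in for real TensorFlow placeholders.

```python
class BasePolicy:
    """Simplified stand-in for stable-baselines' base policy class."""
    recurrent = False  # LSTM policies override this to True


class LstmPolicy(BasePolicy):
    recurrent = True


def get_pretrain_placeholders(policy):
    # Branch on the recurrent flag every policy already carries,
    # instead of testing whether self.initial_state is None.
    if policy.recurrent:
        return "obs_ph", "action_ph", "states_ph", "dones_ph"
    return "obs_ph", "action_ph", None, None
```

The advantage is that the check expresses intent directly and needs no extra attribute on the model.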
@@ -50,6 +50,10 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, pol
        self.sess = None
        self.params = None
        self._param_load_ops = None
        self.initial_state = None
you don't need that variable, there is the recurrent attribute for that
@@ -246,13 +250,24 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,
        else:
            val_interval = int(n_epochs / 10)

        use_lstm = self.initial_state is not None
same remark, you can use the recurrent attribute
@@ -272,13 +287,23 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4,

        for epoch_idx in range(int(n_epochs)):
            train_loss = 0.0
            if use_lstm:
                state = self.initial_state[:envs_per_batch]
initial state is an attribute of the policy
Yes and no. All the models which can use LSTM policies have the variable self.initial_state, which gets set to the initial state from the policy. The variable self.initial_state gets used, not the one in the policy. It is also not that easy to access the initial state from the BaseRLModel. It was much simpler to add the self.initial_state variable to the base model and then let it be overwritten later at model initialization.
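The pattern described above, declaring the attribute on the base model and overwriting it when a recurrent model builds its policy, might look like this minimal sketch. The class names and the shape of the zero state are illustrative assumptions, not the library's actual implementation:

```python
class BaseRLModel:
    """Stand-in for stable-baselines' BaseRLModel."""
    def __init__(self):
        # Declared here so pretrain() can rely on the attribute
        # existing, even for non-recurrent models where it stays None.
        self.initial_state = None


class RecurrentModel(BaseRLModel):
    """Stand-in for a model using an LSTM policy."""
    def __init__(self, n_envs, n_lstm=8):
        super().__init__()
        # Overwrite with the policy's zero state at model
        # initialization: one row per env, cell + hidden state.
        self.initial_state = [[0.0] * (2 * n_lstm) for _ in range(n_envs)]
```

This is why `pretrain` can test `self.initial_state is not None` uniformly, though the reviewer prefers the policy's `recurrent` flag for the same purpose.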
        if use_lstm:
            feed_dict.update({states_ph: state, dones_ph: expert_mask})
            val_loss_, = self.sess.run([loss], feed_dict)
You only need to update the feed_dict; self.sess.run can be called outside, so you avoid code duplication.
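The reviewer's point can be sketched like this: branch only on building the feed dict and keep a single run call outside the branch. To stay self-contained this uses a stand-in `run` callable and string keys rather than TensorFlow itself; all names here are illustrative.

```python
def compute_loss(run, loss_op, obs, actions,
                 state=None, mask=None, use_lstm=False):
    # Common placeholders first...
    feed_dict = {"obs_ph": obs, "actions_ph": actions}
    # ...then the LSTM-only extras, added conditionally.
    if use_lstm:
        feed_dict.update({"states_ph": state, "dones_ph": mask})
    # A single run call outside the branch avoids duplicating it
    # in both the recurrent and non-recurrent paths.
    return run([loss_op], feed_dict)


def fake_run(fetches, feed_dict):
    """Stand-in for sess.run: just reports which keys were fed."""
    return sorted(feed_dict)
```

With a real session, `fake_run` would simply be `sess.run` and the keys would be placeholder tensors.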
:param batch_size: (int) the minibatch size for behavior cloning
:param traj_limitation: (int) the number of trajectory to use (if -1, load all)
:param randomize: (bool) if the dataset should be shuffled
:param randomize: (bool) if the dataset should be shuffled, this will be overwritten to False
    if LSTM is True.
Where is the LSTM variable?
except StopIteration:
    dataloader = iter(dataloader)
    return next(dataloader)

if traj_data is not None and expert_path is not None:
Can't this be checked in the base class? Looks like duplicated code. Also, I'm not sure two classes are needed...
I originally had it in one class, but someone who used the PR suggested splitting it into two classes. I think that was a good idea, because it clearly improved the user-friendliness of the ExpertDataset class.
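The restart-on-StopIteration pattern shown earlier in the diff can be sketched with a small re-iterable loader. The class and helper names here are hypothetical stand-ins for the PR's dataloader, which works the same way in spirit: when a pass over the data is exhausted, a fresh iterator is created so minibatch fetching can continue indefinitely.

```python
class ExpertMiniBatchLoader:
    """Re-iterable loader: each iter() starts a fresh pass."""
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        for i in range(0, len(self.data), self.batch_size):
            yield self.data[i:i + self.batch_size]


def next_batch(loader, it):
    # Mirrors the try/except in the diff: on exhaustion, restart
    # from the loader and return the first batch of the new pass.
    try:
        return next(it), it
    except StopIteration:
        it = iter(loader)
        return next(it), it
```

Note this only works because the loader is an iterable (defines `__iter__` as a generator), not a bare iterator: calling `iter()` on an already-exhausted iterator would just return the same exhausted object.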
tests/test_gail.py
Outdated
model.pretrain(dataset, n_epochs=20)
model.save("test-pretrain")
del dataset, model

@pytest.mark.parametrize("model_class_data", [[A2C, 4, True, "MlpLstmPolicy",
Looks like duplicated code; I think you can handle the different types of policy and expert path in the if.
…aselines into LSTM-pretrain
# Conflicts:
#	stable_baselines/gail/dataset/dataset.py
Hi everyone, I would like to know whether there is still active work on "LSTM support to pretrain". I've seen that this feature was removed from the v2.8.0 milestone more than a month ago. Is the work on hold? Kind regards!
@skervim |
Hello @skervim, as @XMaster96 says, you can use his fork for now if you want to try the feature.
Are you referring to the requested changes? I have partially implemented them and partially commented on why I think I shouldn't change that. Or are you referring to future change requests? I am also aware that I haven't yet written documentation for the website; I was planning to do that when everything is OK and merge-ready.
This PR adds LSTM support to pretrain. I am not quite done yet, but there are some implementation matters that I need to discuss first.

Personal edit: I finally found the time to work more on this PR. The problem is that I took so long that I forgot half of what I did, so if there is some rough code in there, that is why. I still do not have that much time, so do not expect me to answer immediately.
closes #253