From c6397743db5ddd9eab5fb4597fd7684065d9f62e Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 16 Jul 2019 10:34:16 -0600 Subject: [PATCH 1/5] Adding multi-discrete pretraining --- stable_baselines/common/base_class.py | 29 ++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index a0637d8d41..9cad0eb183 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -230,8 +230,8 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, """ continuous_actions = isinstance(self.action_space, gym.spaces.Box) discrete_actions = isinstance(self.action_space, gym.spaces.Discrete) - - assert discrete_actions or continuous_actions, 'Only Discrete and Box action spaces are supported' + multidiscrete_actions = isinstance(self.action_space, gym.spaces.MultiDiscrete) + assert discrete_actions or continuous_actions or multidiscrete_actions, 'Only Discrete, Box, or MultiDiscrete action spaces are supported' # Validate the model every 10% of the total number of iteration if val_interval is None: @@ -246,7 +246,7 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, if continuous_actions: obs_ph, actions_ph, deterministic_actions_ph = self._get_pretrain_placeholders() loss = tf.reduce_mean(tf.square(actions_ph - deterministic_actions_ph)) - else: + elif discrete_actions: obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders() # actions_ph has a shape if (n_batch,), we reshape it to (n_batch, 1) # so no additional changes is needed in the dataloader @@ -257,11 +257,34 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, labels=tf.stop_gradient(one_hot_actions) ) loss = tf.reduce_mean(loss) + elif multidiscrete_actions: + losses = [] + obs_ph, actions_ph, logits = self._get_pretrain_placeholders() + + n_actions = len(self.action_space.nvec) + action_indices = [0] + 
np.cumsum(self.action_space.nvec).tolist() + action_phs = [tf.one_hot(actions_ph[:,i], depth=self.action_space.nvec[i]) for i in range(n_actions)] + + action_logits_phs = [ + logits[:, action_indices[i]:action_indices[i+1]] for i in range(n_actions) + ] + + for one_hot_actions, action_logits_ph in zip(action_phs, action_logits_phs): + print(one_hot_actions, action_logits_ph) + loss_for_subspace = tf.nn.softmax_cross_entropy_with_logits_v2( + logits=action_logits_ph, + labels=tf.stop_gradient(one_hot_actions) + ) + loss_for_subspace = tf.reduce_mean(loss_for_subspace) + losses.append(loss_for_subspace) + + loss = tf.math.add_n(losses) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=adam_epsilon) optim_op = optimizer.minimize(loss, var_list=self.params) self.sess.run(tf.global_variables_initializer()) + if self.verbose > 0: print("Pretraining with Behavior Cloning...") From be8784befd80373475810edaec703301b12ef9e4 Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 16 Jul 2019 10:35:45 -0600 Subject: [PATCH 2/5] Update a2c.py --- stable_baselines/a2c/a2c.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index 55e0662bbb..f76048699b 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -87,7 +87,8 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0. 
def _get_pretrain_placeholders(self): policy = self.train_model - if isinstance(self.action_space, gym.spaces.Discrete): + if isinstance(self.action_space, + (gym.spaces.Discrete, gym.spaces.MultiDiscrete)): return policy.obs_ph, self.actions_ph, policy.policy return policy.obs_ph, self.actions_ph, policy.deterministic_action From 454dfd2edc152b7559cf060fc7ebd6e7c5fb8ca1 Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 16 Jul 2019 10:36:41 -0600 Subject: [PATCH 3/5] Update base_class.py --- stable_baselines/common/base_class.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 9cad0eb183..cfe2698a0b 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -231,7 +231,8 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, continuous_actions = isinstance(self.action_space, gym.spaces.Box) discrete_actions = isinstance(self.action_space, gym.spaces.Discrete) multidiscrete_actions = isinstance(self.action_space, gym.spaces.MultiDiscrete) - assert discrete_actions or continuous_actions or multidiscrete_actions, 'Only Discrete, Box, or MultiDiscrete action spaces are supported' + assert discrete_actions or continuous_actions or multidiscrete_actions, ( + 'Only Discrete, Box, or MultiDiscrete action spaces are supported') # Validate the model every 10% of the total number of iteration if val_interval is None: From 467153a763816354941bbfa8edda041938ee7c7e Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 16 Jul 2019 10:45:18 -0600 Subject: [PATCH 4/5] PEP8 --- stable_baselines/common/base_class.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index cfe2698a0b..a53d732fe1 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -285,7 +285,6 @@ def 
pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, self.sess.run(tf.global_variables_initializer()) - if self.verbose > 0: print("Pretraining with Behavior Cloning...") From 80be62c9d2f12131b55f63173e376db1d081321f Mon Sep 17 00:00:00 2001 From: William Guss Date: Tue, 16 Jul 2019 10:53:05 -0600 Subject: [PATCH 5/5] Update changelog.rst --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e0194726e5..820c369620 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- added support for :code:`gym.spaces.MultiDiscrete` action spaces in pretraining. Bug Fixes: ^^^^^^^^^^