diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1bf981d237..dc52601045 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -19,6 +19,8 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added support for :code:`gym.spaces.MultiDiscrete` spaces in pretraining. (@MadcowD) + Bug Fixes: ^^^^^^^^^^ @@ -451,4 +453,4 @@ In random order... Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs -@Miffyli @dwiel @miguelrass @qxcv @jaberkow +@Miffyli @dwiel @miguelrass @qxcv @jaberkow @MadcowD diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index 55e0662bbb..f76048699b 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -87,7 +87,8 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0. 
def _get_pretrain_placeholders(self): policy = self.train_model - if isinstance(self.action_space, gym.spaces.Discrete): + if (isinstance(self.action_space, gym.spaces.Discrete) + or isinstance(self.action_space, gym.spaces.MultiDiscrete)): return policy.obs_ph, self.actions_ph, policy.policy return policy.obs_ph, self.actions_ph, policy.deterministic_action diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index 14e892a88b..0241536217 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -230,8 +230,9 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, """ continuous_actions = isinstance(self.action_space, gym.spaces.Box) discrete_actions = isinstance(self.action_space, gym.spaces.Discrete) - - assert discrete_actions or continuous_actions, 'Only Discrete and Box action spaces are supported' + multidiscrete_actions = isinstance(self.action_space, gym.spaces.MultiDiscrete) + assert discrete_actions or continuous_actions or multidiscrete_actions, ( + 'Only Discrete, Box, or MultiDiscrete action spaces are supported') # Validate the model every 10% of the total number of iteration if val_interval is None: @@ -246,7 +247,7 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, if continuous_actions: obs_ph, actions_ph, deterministic_actions_ph = self._get_pretrain_placeholders() loss = tf.reduce_mean(tf.square(actions_ph - deterministic_actions_ph)) - else: + elif discrete_actions: obs_ph, actions_ph, actions_logits_ph = self._get_pretrain_placeholders() # actions_ph has a shape if (n_batch,), we reshape it to (n_batch, 1) # so no additional changes is needed in the dataloader @@ -257,6 +258,27 @@ def pretrain(self, dataset, n_epochs=10, learning_rate=1e-4, labels=tf.stop_gradient(one_hot_actions) ) loss = tf.reduce_mean(loss) + elif multidiscrete_actions: + losses = [] + obs_ph, actions_ph, logits = self._get_pretrain_placeholders() + + n_actions = 
len(self.action_space.nvec) + action_indices = [0] + np.cumsum(self.action_space.nvec).tolist() + action_phs = [tf.one_hot(actions_ph[:, i], depth=self.action_space.nvec[i]) for i in range(n_actions)] + + action_logits_phs = [ + logits[:, action_indices[i]:action_indices[i+1]] for i in range(n_actions) + ] + + for one_hot_actions, action_logits_ph in zip(action_phs, action_logits_phs): + loss_for_subspace = tf.nn.softmax_cross_entropy_with_logits_v2( + logits=action_logits_ph, + labels=tf.stop_gradient(one_hot_actions) + ) + loss_for_subspace = tf.reduce_mean(loss_for_subspace) + losses.append(loss_for_subspace) + + loss = tf.math.add_n(losses) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=adam_epsilon) optim_op = optimizer.minimize(loss, var_list=self.params)