diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py
index e7d6e7a395..eaec8fc45b 100644
--- a/ding/model/template/q_learning.py
+++ b/ding/model/template/q_learning.py
@@ -1118,8 +1118,9 @@ def __init__(
             gru_bias=gru_bias,
         )
 
+        # for vector obs, use Identity Encoder, i.e. pass
         if isinstance(obs_shape, int) or len(obs_shape) == 1:
-            raise NotImplementedError("not support obs_shape for pre-defined encoder: {}".format(obs_shape))
+            pass
         # replace the embedding layer of Transformer with Conv Encoder
         elif len(obs_shape) == 3:
             assert encoder_hidden_size_list[-1] == hidden_size
diff --git a/ding/policy/r2d2_gtrxl.py b/ding/policy/r2d2_gtrxl.py
index 73b89239f3..9d67289685 100644
--- a/ding/policy/r2d2_gtrxl.py
+++ b/ding/policy/r2d2_gtrxl.py
@@ -55,11 +55,9 @@ class R2D2GTrXLPolicy(Policy):
            | ``done``                                | calculation.                            | fake termination env
        15  ``collect.n_sample``  int      [8, 128]  | The number of training samples of a     | It varies from
                                                     | call of collector.                      | different envs
-       16  | ``collect.unroll``  int      25        | unroll length of an iteration           | unroll_len>1
+       16  | ``collect.seq``     int      20        | Training sequence length                | unroll_len>=seq_len>1
            | ``_len``
-       17  | ``collect.seq``     int      20        | Training sequence length                | unroll_len>=seq_len>1
-           | ``_len``
-       18  | ``learn.init_``     str      zero      | 'zero' or 'old', how to initialize the  |
+       17  | ``learn.init_``     str      zero      | 'zero' or 'old', how to initialize the  |
            | ``memory``                             | memory before each training iteration.  |
        == ==================== ======== ============== ======================================== =======================
    """
@@ -81,7 +79,7 @@ class R2D2GTrXLPolicy(Policy):
         discount_factor=0.99,
         # (int) N-step reward for target q_value estimation
         nstep=5,
-        # how many steps to use as burnin
+        # (int) How many steps to use in burnin phase
         burnin_step=1,
         # (int) trajectory length
         unroll_len=25,
@@ -158,7 +156,7 @@ def _init_learn(self) -> None:
         self._seq_len = self._cfg.seq_len
         self._value_rescale = self._cfg.learn.value_rescale
         self._init_memory = self._cfg.learn.init_memory
-        assert self._init_memory in ['zero', 'old']
+        assert self._init_memory in ['zero', 'old'], self._init_memory
 
         self._target_model = copy.deepcopy(self._model)
@@ -352,7 +350,6 @@ def _init_collect(self) -> None:
         Collect mode init method. Called by ``self.__init__``.
         Init unroll length and sequence len, collect model.
         """
-        assert 'unroll_len' not in self._cfg.collect, "Use default unroll_len"
         self._nstep = self._cfg.nstep
         self._gamma = self._cfg.discount_factor
         self._unroll_len = self._cfg.unroll_len
diff --git a/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py b/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py
index 5a2cdac020..a165245fc1 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py
@@ -18,11 +18,11 @@
             obs_shape=4,
             action_shape=2,
             memory_len=5,  # length of transformer memory (can be 0)
-            hidden_size=256,
+            hidden_size=64,
             gru_bias=2.,
             att_layer_num=3,
             dropout=0.,
-            att_head_num=8,
+            att_head_num=4,
         ),
         discount_factor=0.99,
         nstep=3,
@@ -31,7 +31,7 @@
         seq_len=8,  # transformer input segment
         # training sequence: unroll_len - burnin_step - nstep
         learn=dict(
-            update_per_collect=8,
+            update_per_collect=16,
             batch_size=64,
             learning_rate=0.0005,
             target_update_freq=500,
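
The q_learning.py change lets GTrXL-based models accept 1-D vector observations by falling through to the Transformer's own embedding layer instead of raising. A minimal standalone sketch of that dispatch, where the helper name `build_obs_encoder` and the conv stack are illustrative assumptions, not DI-engine's actual encoder:

```python
import torch
import torch.nn as nn

def build_obs_encoder(obs_shape, encoder_hidden_size_list, hidden_size):
    # Vector obs: identity, i.e. the new `pass` branch -- the Transformer's
    # embedding layer projects the raw features itself.
    if isinstance(obs_shape, int) or len(obs_shape) == 1:
        return nn.Identity()
    # Image obs: a conv encoder replaces the embedding layer; its output
    # width must equal the Transformer hidden size (hence the assert).
    elif len(obs_shape) == 3:
        assert encoder_hidden_size_list[-1] == hidden_size
        return nn.Sequential(
            nn.Conv2d(obs_shape[0], encoder_hidden_size_list[-1], kernel_size=8, stride=4),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
    raise NotImplementedError("not support obs_shape for pre-defined encoder: {}".format(obs_shape))

vec_enc = build_obs_encoder(4, [64], 64)
print(vec_enc(torch.randn(2, 4)).shape)          # torch.Size([2, 4]), passed through
img_enc = build_obs_encoder((3, 84, 84), [32, 64], 64)
print(img_enc(torch.randn(2, 3, 84, 84)).shape)  # torch.Size([2, 64])
```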
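On the `learn.init_memory` option ('zero' or 'old'): the patched assert now echoes the offending value when it fires, which makes misconfigured strings easy to spot. A hedged sketch of the two modes, where the function name and the memory tensor layout `(layer_num + 1, memory_len, batch, hidden)` are assumptions rather than the policy's actual internals:

```python
import torch

def init_gtrxl_memory(init_memory, saved_memory=None,
                      layer_num=3, memory_len=5, batch_size=64, hidden_size=64):
    # A failing assert now reports the bad value, e.g. AssertionError: zeros
    assert init_memory in ['zero', 'old'], init_memory
    if init_memory == 'zero' or saved_memory is None:
        # 'zero': start every training iteration from a blank memory.
        return torch.zeros(layer_num + 1, memory_len, batch_size, hidden_size)
    # 'old': reuse the memory snapshot recorded during collection.
    return saved_memory

mem = init_gtrxl_memory('zero')
print(mem.shape)  # torch.Size([4, 5, 64, 64])
```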
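On the retuned cartpole config: multi-head attention needs `hidden_size` to divide evenly across `att_head_num` (64 / 4 = 16 per head, just as the old 256 / 8 = 32 did), and the in-file comment defines the effective training sequence as `unroll_len - burnin_step - nstep`. A quick check of both invariants with illustrative values (`unroll_len` and `burnin_step` are not shown in the hunk, so 20 and 1 here are assumptions):

```python
cfg = dict(hidden_size=64, att_head_num=4, unroll_len=20, burnin_step=1, nstep=3, seq_len=8)
assert cfg['hidden_size'] % cfg['att_head_num'] == 0         # per-head dim: 16
assert cfg['unroll_len'] >= cfg['seq_len'] > 1               # from the docstring table
print(cfg['unroll_len'] - cfg['burnin_step'] - cfg['nstep'])  # 16 training steps
```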