From 55dc25472238f8f057ff181417e390cc9849ca73 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com>
Date: Mon, 28 Oct 2024 15:06:45 +0800
Subject: [PATCH] polish(pu): adapt qmix's mixer to support image obs

---
 ding/model/template/qmix.py                      | 17 +++++++++++++----
 .../config/ptz_pistonball_qmix_config.py         |  6 +++---
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py
index 68354e0cf7..a4af8e0ba8 100644
--- a/ding/model/template/qmix.py
+++ b/ding/model/template/qmix.py
@@ -5,6 +5,7 @@ from functools import reduce
 
 from ding.utils import list_split, MODEL_REGISTRY
 from ding.torch_utils import fc_block, MLP
+from ..common import ConvEncoder
 from .q_learning import DRQN
 
 
@@ -111,7 +112,7 @@ def __init__(
         self,
         agent_num: int,
         obs_shape: int,
-        global_obs_shape: int,
+        global_obs_shape: Union[int, List[int]],
         action_shape: int,
         hidden_size_list: list,
         mixer: bool = True,
@@ -146,8 +147,14 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
-            self._global_state_encoder = nn.Identity()
+            if len(global_obs_shape) == 1:
+                self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+                self._global_state_encoder = nn.Identity()
+            elif len(global_obs_shape) == 3:
+                self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
+                self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
+            else:
+                raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
 
     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
@@ -183,7 +190,9 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
             'prev_state']
         action = data.get('action', None)
         if single_step:
-            agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+            agent_state = agent_state.unsqueeze(0)
+        if single_step and len(global_state.shape) == 2:
+            global_state = global_state.unsqueeze(0)
         T, B, A = agent_state.shape[:3]
         assert len(prev_state) == B and all(
             [len(p) == A for p in prev_state]
diff --git a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
index f1b2da682a..7af0289d89 100644
--- a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
+++ b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
@@ -21,6 +21,7 @@
             shared_memory=False,
             reset_timeout=6000,
         ),
+        max_env_step=3e6,
     ),
     policy=dict(
         cuda=True,
@@ -30,8 +31,7 @@
             global_obs_shape=(3, 560, 880),  # Global state shape
             action_shape=3,  # Discrete actions (0, 1, 2)
             hidden_size_list=[128, 128, 64],
-            # mixer=True,  # TODO: mixer is not supported image observation now
-            mixer=False,
+            mixer=True,
         ),
         learn=dict(
             update_per_collect=100,
@@ -73,4 +73,4 @@
 if __name__ == '__main__':
     # or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
     from ding.entry import serial_pipeline
-    serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
+    serial_pipeline((main_config, create_config), seed=0, max_env_step=main_config.env.max_env_step)
\ No newline at end of file
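
Note: below is a minimal smoke-test sketch (not part of the commit) of the encoder/mixer pairing the patched __init__ selects when global_obs_shape is 3-dimensional: the image global state is first mapped by ConvEncoder to an embedding_size-dim vector, and the Mixer hypernetworks then operate on that embedding instead of flattened pixels. The import paths and the ConvEncoder/Mixer constructor arguments follow the diff above; the batch size, the hypothetical 3x64x64 global state, the random tensors, and the assumption that Mixer.forward takes (agent_qs, states) as in DI-engine are illustrative only.

# Illustrative sketch (not from the commit): exercise the image-global-state branch
# that the patched QMix.__init__ sets up.
import torch
import torch.nn as nn
from ding.model.common import ConvEncoder
from ding.model.template.qmix import Mixer

agent_num = 4
hidden_size_list = [128, 128, 64]   # as in the pistonball config; last entry is the embedding size
embedding_size = hidden_size_list[-1]
global_obs_shape = [3, 64, 64]      # hypothetical small (C, H, W) global state

# A 3-dim global_obs_shape means ConvEncoder produces an embedding_size-dim state,
# so the Mixer is built with state_dim=embedding_size rather than the raw obs dim.
global_state_encoder = ConvEncoder(
    global_obs_shape, hidden_size_list=hidden_size_list, activation=nn.ReLU(), norm_type='BN'
)
mixer = Mixer(agent_num, embedding_size, embedding_size, activation=nn.ReLU())

B = 8
global_state = torch.randn(B, *global_obs_shape)
agent_q_act = torch.randn(B, agent_num)               # per-agent Q value of the chosen action
state_embedding = global_state_encoder(global_state)  # (B, embedding_size)
total_q = mixer(agent_q_act, state_embedding)         # (B, ), assuming Mixer.forward(agent_qs, states)
print(state_embedding.shape, total_q.shape)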