Skip to content

Commit

Permalink
polish(pu): adapt qmix's mixer to support image obs
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Oct 28, 2024
1 parent e916841 commit 55dc254
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
17 changes: 13 additions & 4 deletions ding/model/template/qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import reduce
from ding.utils import list_split, MODEL_REGISTRY
from ding.torch_utils import fc_block, MLP
from ..common import ConvEncoder
from .q_learning import DRQN


Expand Down Expand Up @@ -111,7 +112,7 @@ def __init__(
self,
agent_num: int,
obs_shape: int,
global_obs_shape: int,
global_obs_shape: Union[int, List[int]],
action_shape: int,
hidden_size_list: list,
mixer: bool = True,
Expand Down Expand Up @@ -146,8 +147,14 @@ def __init__(
embedding_size = hidden_size_list[-1]
self.mixer = mixer
if self.mixer:
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
if len(global_obs_shape) == 1:
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
self._global_state_encoder = nn.Identity()
elif len(global_obs_shape) == 3:
self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
else:
raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))

def forward(self, data: dict, single_step: bool = True) -> dict:
"""
Expand Down Expand Up @@ -183,7 +190,9 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
'prev_state']
action = data.get('action', None)
if single_step:
agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
agent_state = agent_state.unsqueeze(0)
if single_step and len(global_state.shape) == 2:
global_state = global_state.unsqueeze(0)
T, B, A = agent_state.shape[:3]
assert len(prev_state) == B and all(
[len(p) == A for p in prev_state]
Expand Down
6 changes: 3 additions & 3 deletions dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
shared_memory=False,
reset_timeout=6000,
),
max_env_step=3e6,
),
policy=dict(
cuda=True,
Expand All @@ -30,8 +31,7 @@
global_obs_shape=(3, 560, 880), # Global state shape
action_shape=3, # Discrete actions (0, 1, 2)
hidden_size_list=[128, 128, 64],
# mixer=True, # TODO: mixer is not supported image observation now
mixer=False,
mixer=True,
),
learn=dict(
update_per_collect=100,
Expand Down Expand Up @@ -73,4 +73,4 @@
if __name__ == '__main__':
# or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0, max_env_step=main_config.env.max_env_step)

0 comments on commit 55dc254

Please sign in to comment.