From d182f79d37ddf3f5e72eb8559a50f30693e75e11 Mon Sep 17 00:00:00 2001 From: leoxhwang <1134086740@qq.com> Date: Sun, 7 Jul 2024 21:47:41 +0800 Subject: [PATCH] feat: :sparkles: update space to gym.spaces.Dict --- smac_pettingzoo/smacv1_pettingzoo_v1.py | 36 ++++++++++++++++--------- smac_pettingzoo/smacv2_pettingzoo_v1.py | 36 ++++++++++++++++--------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/smac_pettingzoo/smacv1_pettingzoo_v1.py b/smac_pettingzoo/smacv1_pettingzoo_v1.py index a2ede55..b0b04f2 100644 --- a/smac_pettingzoo/smacv1_pettingzoo_v1.py +++ b/smac_pettingzoo/smacv1_pettingzoo_v1.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Type import co_mas -import gymnasium.spaces as gspaces +import gymnasium as gym import numpy as np import pettingzoo import pettingzoo.utils @@ -24,15 +24,25 @@ def __init__(self, map_name: str, smacv1_env_args: dict = {}): self._env.reset(0) self._init_agents() - self.observation_spaces = { - agent: gspaces.Box(low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32) - for agent_i, agent in enumerate(self.agents) - } - self.action_spaces = {agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)} - self.state_spaces = { - agent: gspaces.Box(low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32) - for agent_i, agent in enumerate(self.agents) - } + self.observation_spaces = gym.spaces.Dict( + { + agent: gym.spaces.Box( + low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32 + ) + for agent_i, agent in enumerate(self.agents) + } + ) + self.action_spaces = gym.spaces.Dict( + {agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)} + ) + self.state_spaces = gym.spaces.Dict( + { + agent: gym.spaces.Box( + low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32 + ) + for agent_i, agent in enumerate(self.agents) + } + ) self.states = None @@ -73,13 +83,13 @@ def _init_agents(self): self.possible_agents = self.agents[:] self.agents_to_agent_ids = {agent: agent_id for agent_id, agent in enumerate(self.possible_agents)} - def observation_space(self, agent: Any) -> gspaces.Space: + def observation_space(self, agent: Any) -> gym.spaces.Space: return self.observation_spaces[agent] - def action_space(self, agent: Any) -> gspaces.Space: + def action_space(self, agent: Any) -> gym.spaces.Space: return self.action_spaces[agent] - def state_space(self, agent: Any) -> gspaces.Space: + def state_space(self, agent: Any) -> gym.spaces.Space: return self.state_spaces[agent] def state(self) -> Dict: diff --git a/smac_pettingzoo/smacv2_pettingzoo_v1.py b/smac_pettingzoo/smacv2_pettingzoo_v1.py index 117ab73..745c59a 100644 --- a/smac_pettingzoo/smacv2_pettingzoo_v1.py +++ b/smac_pettingzoo/smacv2_pettingzoo_v1.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Tuple, Type import co_mas -import gymnasium.spaces as gspaces +import gymnasium as gym import numpy as np import pettingzoo from loguru import logger @@ -61,15 +61,25 @@ def __init__( self._env.reset(0) self._init_agents() - self.observation_spaces = { - agent: gspaces.Box(low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32) - for agent_i, agent in enumerate(self.agents) - } - self.action_spaces = {agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)} - self.state_spaces = { - agent: gspaces.Box(low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32) - for agent_i, agent in enumerate(self.agents) - } + self.observation_spaces = gym.spaces.Dict( + { + agent: gym.spaces.Box( + low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32 + ) + for agent_i, agent in enumerate(self.agents) + } + ) + self.action_spaces = gym.spaces.Dict( + {agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)} + ) + self.state_spaces = gym.spaces.Dict( + { + agent: gym.spaces.Box( + low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32 + ) + for agent_i, agent in enumerate(self.agents) + } + ) self.states = None @@ -160,13 +170,13 @@ def _init_agents(self): self.possible_agents = self.agents[:] self.agents_to_agent_ids = {agent: agent_id for agent_id, agent in enumerate(self.possible_agents)} - def observation_space(self, agent: Any) -> gspaces.Space: + def observation_space(self, agent: Any) -> gym.spaces.Space: return self.observation_spaces[agent] - def action_space(self, agent: Any) -> gspaces.Space: + def action_space(self, agent: Any) -> gym.spaces.Space: return self.action_spaces[agent] - def state_space(self, agent: Any) -> gspaces.Space: + def state_space(self, agent: Any) -> gym.spaces.Space: return self.state_spaces[agent] def state(self) -> Dict: