Skip to content

Commit

Permalink
feat: ✨ update space to gym.spaces.Dict
Browse files Browse the repository at this point in the history
  • Loading branch information
leoxhwang committed Jul 7, 2024
1 parent 5738657 commit d182f79
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
36 changes: 23 additions & 13 deletions smac_pettingzoo/smacv1_pettingzoo_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple, Type

import co_mas
import gymnasium.spaces as gspaces
import gymnasium as gym
import numpy as np
import pettingzoo
import pettingzoo.utils
Expand All @@ -24,15 +24,25 @@ def __init__(self, map_name: str, smacv1_env_args: dict = {}):
self._env.reset(0)
self._init_agents()

self.observation_spaces = {
agent: gspaces.Box(low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32)
for agent_i, agent in enumerate(self.agents)
}
self.action_spaces = {agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)}
self.state_spaces = {
agent: gspaces.Box(low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32)
for agent_i, agent in enumerate(self.agents)
}
self.observation_spaces = gym.spaces.Dict(
{
agent: gym.spaces.Box(
low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32
)
for agent_i, agent in enumerate(self.agents)
}
)
self.action_spaces = gym.spaces.Dict(
{agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)}
)
self.state_spaces = gym.spaces.Dict(
{
agent: gym.spaces.Box(
low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32
)
for agent_i, agent in enumerate(self.agents)
}
)

self.states = None

Expand Down Expand Up @@ -73,13 +83,13 @@ def _init_agents(self):
self.possible_agents = self.agents[:]
self.agents_to_agent_ids = {agent: agent_id for agent_id, agent in enumerate(self.possible_agents)}

def observation_space(self, agent: Any) -> gspaces.Space:
def observation_space(self, agent: Any) -> gym.spaces.Space:
return self.observation_spaces[agent]

def action_space(self, agent: Any) -> gspaces.Space:
def action_space(self, agent: Any) -> gym.spaces.Space:
return self.action_spaces[agent]

def state_space(self, agent: Any) -> gspaces.Space:
def state_space(self, agent: Any) -> gym.spaces.Space:
return self.state_spaces[agent]

def state(self) -> Dict:
Expand Down
36 changes: 23 additions & 13 deletions smac_pettingzoo/smacv2_pettingzoo_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Tuple, Type

import co_mas
import gymnasium.spaces as gspaces
import gymnasium as gym
import numpy as np
import pettingzoo
from loguru import logger
Expand Down Expand Up @@ -61,15 +61,25 @@ def __init__(
self._env.reset(0)
self._init_agents()

self.observation_spaces = {
agent: gspaces.Box(low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32)
for agent_i, agent in enumerate(self.agents)
}
self.action_spaces = {agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)}
self.state_spaces = {
agent: gspaces.Box(low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32)
for agent_i, agent in enumerate(self.agents)
}
self.observation_spaces = gym.spaces.Dict(
{
agent: gym.spaces.Box(
low=-1, high=1, shape=(self._env.observation_space[agent_i][0],), dtype=np.float32
)
for agent_i, agent in enumerate(self.agents)
}
)
self.action_spaces = gym.spaces.Dict(
{agent: self._env.action_space[agent_i] for agent_i, agent in enumerate(self.agents)}
)
self.state_spaces = gym.spaces.Dict(
{
agent: gym.spaces.Box(
low=-1, high=1, shape=(self._env.share_observation_space[agent_i][0],), dtype=np.float32
)
for agent_i, agent in enumerate(self.agents)
}
)

self.states = None

Expand Down Expand Up @@ -160,13 +170,13 @@ def _init_agents(self):
self.possible_agents = self.agents[:]
self.agents_to_agent_ids = {agent: agent_id for agent_id, agent in enumerate(self.possible_agents)}

def observation_space(self, agent: Any) -> gspaces.Space:
def observation_space(self, agent: Any) -> gym.spaces.Space:
return self.observation_spaces[agent]

def action_space(self, agent: Any) -> gspaces.Space:
def action_space(self, agent: Any) -> gym.spaces.Space:
return self.action_spaces[agent]

def state_space(self, agent: Any) -> gspaces.Space:
def state_space(self, agent: Any) -> gym.spaces.Space:
return self.state_spaces[agent]

def state(self) -> Dict:
Expand Down

0 comments on commit d182f79

Please sign in to comment.