Skip to content

Commit

Permalink
feat: 👽 update due to co_mas APIs change
Browse files Browse the repository at this point in the history
  • Loading branch information
leoxhwang committed Jul 6, 2024
1 parent 988b3d6 commit 5738657
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 28 deletions.
4 changes: 2 additions & 2 deletions smac_pettingzoo/env/smacv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,11 +1779,11 @@ def get_avail_agent_actions(self, agent_id):
if dist <= shoot_range:
avail_actions[t_id + self.n_actions_no_attack] = 1

return avail_actions
return np.array(avail_actions)

else:
# only no-op allowed
return [1] + [0] * (self.n_actions - 1)
return np.array([1] + [0] * (self.n_actions - 1))

def get_avail_actions(self):
"""Returns the available actions of all agents in a list."""
Expand Down
4 changes: 2 additions & 2 deletions smac_pettingzoo/env/smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2260,11 +2260,11 @@ def get_avail_agent_actions(self, agent_id):
if can_shoot:
avail_actions[t_id + self.n_actions_no_attack] = 1

return avail_actions
return np.array(avail_actions)

else:
# only no-op allowed
return [1] + [0] * (self.n_actions - 1)
return np.array([1] + [0] * (self.n_actions - 1))

def get_avail_actions(self):
"""Returns the available actions of all agents in a list."""
Expand Down
3 changes: 2 additions & 1 deletion smac_pettingzoo/smacv1_pettingzoo_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def step(self, actions: Dict) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
return observations, rewards, terminations, truncations, infos

def close(self):
self._env.close()
if hasattr(self, "_env") and self._env is not None:
self._env.close()


def parallel_env(
Expand Down
3 changes: 2 additions & 1 deletion smac_pettingzoo/smacv2_pettingzoo_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def step(self, actions: Dict) -> Tuple[Dict, Dict, Dict, Dict, Dict]:
return observations, rewards, terminations, truncations, infos

def close(self):
self._env.close()
if hasattr(self, "_env") and self._env is not None:
self._env.close()


def parallel_env(
Expand Down
23 changes: 15 additions & 8 deletions tests/smacv1_pettingzoo_v1_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from co_mas.test import parallel_api_test, sample_action
import numpy as np
from co_mas.test.parallel_api import parallel_api_test, sample_action
from loguru import logger

from smac_pettingzoo import smacv1_pettingzoo_v1
Expand All @@ -10,7 +11,7 @@
step = 0
while True:
obs, _, terminated, truncated, info = env.step(
{agent: sample_action(env, obs, agent, info) for agent in env.agents}
{agent: sample_action(agent, obs[agent], info[agent], env.action_space(agent)) for agent in env.agents}
)
step += 1
if len(env.agents) <= 0:
Expand All @@ -28,11 +29,12 @@
env1 = smacv1_pettingzoo_v1.parallel_env("8m", {})
obs1_list = []
obs1, info1 = env1.reset(seed=42)
np.random.seed(42)
obs1_list.append(obs1)

while True:
obs1, _, terminated1, _, info1 = env1.step(
{agent: sample_action(env1, obs1, agent, info1) for agent in env1.agents}
{agent: sample_action(agent, obs1[agent], info1[agent], env1.action_space(agent)) for agent in env1.agents}
)
obs1_list.append(obs1)

Expand All @@ -45,25 +47,26 @@

obs2_list = []
obs2, info2 = env2.reset(seed=42)
np.random.seed(42)
obs2_list.append(obs2)

while True:
obs2, _, terminated2, _, info2 = env2.step(
{agent: sample_action(env2, obs2, agent, info2) for agent in env2.agents}
{agent: sample_action(agent, obs2[agent], info2[agent], env2.action_space(agent)) for agent in env2.agents}
)
obs2_list.append(obs2)

if any(terminated2.values()):
break

env2.close()

for i, (obs1, obs2) in enumerate(zip(obs1_list, obs2_list)):
assert all(
(obs1[agent] == obs2[agent]).all() for agent in env1.agents
), f"Observations at step {i} differ:\n{obs1}\n{obs2}"

logger.success("Seed Test Passed!")
env1.close()
env2.close()

# Wrapper Tests
from co_mas.wrappers import AutoResetParallelEnvWrapper, OrderForcingParallelEnvWrapper
Expand All @@ -75,12 +78,16 @@
obs, info = env.reset(seed=42)

while True:
obs, _, terminated, _, info = env.step({agent: sample_action(env, obs, agent, info) for agent in env.agents})
obs, _, terminated, _, info = env.step(
{agent: sample_action(agent, obs[agent], info[agent], env.action_space(agent)) for agent in env.agents}
)

if all(terminated.values()):
break

obs, _, terminated, _, info = env.step({agent: sample_action(env, obs, agent, info) for agent in env.agents})
obs, _, terminated, _, info = env.step(
{agent: sample_action(agent, obs[agent], info[agent], env.action_space(agent)) for agent in env.agents}
)

assert terminated != {agent: True for agent in env.agents}

Expand Down
35 changes: 21 additions & 14 deletions tests/smacv2_pettingzoo_v1_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from co_mas.test import parallel_api_test, sample_action
import numpy as np
from co_mas.test.parallel_api import parallel_api_test, sample_action
from loguru import logger

from smac_pettingzoo import smacv2_pettingzoo_v1
Expand All @@ -10,7 +11,7 @@
step = 0
while True:
obs, _, terminated, truncated, info = env.step(
{agent: sample_action(env, obs, agent, info) for agent in env.agents}
{agent: sample_action(agent, obs[agent], info[agent], env.action_space(agent)) for agent in env.agents}
)
step += 1
if len(env.agents) <= 0:
Expand All @@ -19,24 +20,25 @@
ep_i += 1
if any(truncated.values()):
break

env.close()

parallel_api_test(env, 200)
parallel_api_test(env, 400)

# Seed Tests
env1 = smacv2_pettingzoo_v1.parallel_env("10gen_terran_10_vs_10")
obs1_list = []
obs1, info1 = env1.reset(seed=42)
np.random.seed(42)
obs1_list.append(obs1)

while True:
obs1, _, terminated1, truncated1, info1 = env1.step(
{agent: sample_action(env1, obs1, agent, info1) for agent in env1.agents}
obs1, _, terminated1, _, info1 = env1.step(
{agent: sample_action(agent, obs1[agent], info1[agent], env1.action_space(agent)) for agent in env1.agents}
)
obs1_list.append(obs1)
logger.trace(f"{env1.agents}")

if len(env1.agents) <= 0:
if any(terminated1.values()):
break

env1.close()
Expand All @@ -45,25 +47,26 @@

obs2_list = []
obs2, info2 = env2.reset(seed=42)
np.random.seed(42)
obs2_list.append(obs2)

while True:
obs2, _, terminated2, _, info2 = env2.step(
{agent: sample_action(env2, obs2, agent, info2) for agent in env2.agents}
{agent: sample_action(agent, obs2[agent], info2[agent], env2.action_space(agent)) for agent in env2.agents}
)
obs2_list.append(obs2)

if len(env2.agents) <= 0:
if any(terminated2.values()):
break

env2.close()

for i, (obs1, obs2) in enumerate(zip(obs1_list, obs2_list)):
assert all(
(obs1[agent] == obs2[agent]).all() for agent in env1.agents
), f"Observations at step {i} differ:\n{obs1}\n{obs2}"

logger.success("Seed Test Passed!")
env1.close()
env2.close()

# Wrapper Tests
from co_mas.wrappers import AutoResetParallelEnvWrapper, OrderForcingParallelEnvWrapper
Expand All @@ -75,12 +78,16 @@
obs, info = env.reset(seed=42)

while True:
obs, _, terminated, _, info = env.step({agent: sample_action(env, obs, agent, info) for agent in env.agents})
obs, _, terminated, _, info = env.step(
{agent: sample_action(agent, obs[agent], info[agent], env.action_space(agent)) for agent in env.agents}
)

if len(env.agents) <= 0:
if all(terminated.values()):
break

obs, _, terminated, _, info = env.step({agent: sample_action(env, obs, agent, info) for agent in env.agents})
obs, _, terminated, _, info = env.step(
{agent: sample_action(agent, obs[agent], info[agent], env.action_space(agent)) for agent in env.agents}
)

assert terminated != {agent: True for agent in env.agents}

Expand Down

0 comments on commit 5738657

Please sign in to comment.