Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pettingzoo Integration #32

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SCM syntax highlighting
pixi.lock linguist-language=YAML linguist-generated=true
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pixi.lock

# Created by https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode
# Edit at https://www.gitignore.io/?templates=linux,python,windows,pycharm+all,visualstudiocode
Expand Down Expand Up @@ -245,4 +246,7 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk

# End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode
# End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode
# pixi environments
.pixi
*.egg-info
48 changes: 48 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
[project]
name = "rware"
version = "2.0.0"
description = "Multi-Robot Warehouse environment for reinforcement learning"
readme = { content-type = "text/markdown", file = "README.md" }
maintainers = [{ name = "Filippos Christianos" }]
classifiers = [
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.7",
]
requires-python = ">=3.7"
urls = { github = "https://github.com/semitable/robotic-warehouse" }
dependencies = ["numpy", "gymnasium", "pyglet<2", "networkx"]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
test = ["pytest"]
pettingzoo = ["pettingzoo"]

[tool.setuptools.packages.find]
exclude = ["contrib", "docs", "tests"]

# pixi
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]
preview = ["pixi-build"]

[tool.pixi.environments]
default = { solve-group = "default" }
test = { features = ["test", "pettingzoo"], solve-group = "default" }

[tool.pixi.pypi-dependencies]
rware = { path = ".", editable = true }

[tool.pixi.package]
name = "rware"
version = "2.0.1"

[tool.pixi.build-system]
build-backend = { name = "pixi-build-python", version = "*" }
channels = ["pixi-build-backends", "conda-forge"]

[tool.pixi.feature.test.tasks]
test = "pytest"
98 changes: 98 additions & 0 deletions rware/pettingzoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Dict, Tuple, List, Optional
import warnings

import gymnasium as gym
import numpy as np
from pettingzoo import ParallelEnv

from .warehouse import Warehouse

# ID are str(integers), which represent the agent.id (agent idx+1) in env.agents.
# Set to str for compatability with TorchRL.
AgentID = str
# TODO: Refactor Action object to include the message bits.
ActionType = object
ObsType = np.ndarray


def to_agentid_dict(data: List):
return {str(i + 1): x for i, x in enumerate(data)}


class PettingZooWrapper(ParallelEnv):
"""Wraps a Warehouse Env object to be compatible with the PettingZoo ParallelEnv API."""

def __init__(self, env: Warehouse):
super().__init__()
self._env = env
self.agents = self.possible_agents = []
self.observation_spaces = self.action_spaces = {}

def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None):
obs, info = self._env.reset(seed, options)
obs = to_agentid_dict(obs)
info = {str(i + 1): {} for i in range(self._env.n_agents)}
# Reset agents and spaces
self.agents = [str(agent.id) for agent in self._env.agents]
self.possible_agents = self.agents
self.observation_spaces = {
agent_id: self.observation_space(agent_id)
for agent_id in [str(i + 1) for i in range(self._env.n_agents)]
}
self.action_spaces = {
agent_id: self.action_space(agent_id)
for agent_id in [str(i + 1) for i in range(self._env.n_agents)]
}
return obs, info

def step(self, actions: dict[AgentID, ActionType]) -> Tuple[
dict[AgentID, ObsType],
dict[AgentID, float],
dict[AgentID, bool],
dict[AgentID, bool],
dict[AgentID, dict],
]:
# Unwrap to list of actions
actions_unwrapped = [(int(id_) - 1, action) for id_, action in actions.items()]
actions_unwrapped.sort(key=lambda x: x[0])
actions_unwrapped = [x[1] for x in actions_unwrapped]
assert (
len(actions_unwrapped) == self._env.n_agents
), f"Incorrect number of actions provided. Expected {self._env.n_agents} but got {len(actions_unwrapped)}"

# Step inner environment
obs, rewards, terminated, truncated, info = self._env.step(actions_unwrapped)

# Transform to PettingZoo output
obs = to_agentid_dict(obs)
rewards = to_agentid_dict(rewards)
if terminated or truncated:
self.agents = [] # PettingZoo requires agents to be removed
terminated = to_agentid_dict([terminated for _ in range(self._env.n_agents)])
truncated = to_agentid_dict([truncated for _ in range(self._env.n_agents)])
if len(info) != 0:
warnings.warn(
"Error: expected info dict to be empty. PettingZooWrapper is likely out of date."
)
info = {str(i + 1): {} for i in range(self._env.n_agents)}

return obs, rewards, terminated, truncated, info

def render(self):
return self._env.render()

def close(self) -> None:
self._env.close()

def state(self):
return self._env.get_global_image()

def observation_space(self, agent: AgentID) -> gym.spaces.Space:
space = self._env.observation_space
assert isinstance(space, gym.spaces.Tuple)
return space[int(agent) - 1]

def action_space(self, agent: AgentID) -> gym.spaces.Space:
space = self._env.action_space
assert isinstance(space, gym.spaces.Tuple)
return space[int(agent) - 1]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
"pyglet<2",
"networkx",
],
extras_require={"test": ["pytest"]},
extras_require={"test": ["pytest"], "pettingzoo": ["pettingzoo"]},
include_package_data=True,
)
53 changes: 53 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Optional
import importlib
import pytest

from rware.warehouse import Warehouse, RewardType, ObservationType

_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None
if _has_pettingzoo:
from pettingzoo.test import parallel_api_test
from rware.pettingzoo import PettingZooWrapper


@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("msg_bits", [0, 1])
@pytest.mark.parametrize("sensor_range", [1, 3])
@pytest.mark.parametrize("max_inactivity_steps", [None, 10])
@pytest.mark.parametrize("reward_type", [RewardType.GLOBAL, RewardType.INDIVIDUAL])
@pytest.mark.parametrize(
"observation_type",
[
ObservationType.DICT,
ObservationType.IMAGE,
ObservationType.IMAGE_DICT,
ObservationType.FLATTENED,
],
)
def test_pettingzoo_wrapper(
n_agents: int,
msg_bits: int,
sensor_range: int,
max_inactivity_steps: Optional[int],
reward_type: RewardType,
observation_type: ObservationType,
):
if not _has_pettingzoo:
pytest.skip("PettingZoo not available.")
return

env = Warehouse(
shelf_columns=1,
column_height=5,
shelf_rows=3,
n_agents=n_agents,
msg_bits=msg_bits,
sensor_range=sensor_range,
request_queue_size=5,
max_inactivity_steps=max_inactivity_steps,
max_steps=None,
reward_type=reward_type,
observation_type=observation_type,
)
env = PettingZooWrapper(env)
parallel_api_test(env)