Add support for pixel-obs environments #21

Merged: 26 commits, Mar 11, 2024
9 changes: 7 additions & 2 deletions Dockerfile
@@ -2,6 +2,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

# Ensure no installs try to launch interactive screen
ARG DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

# Update packages and install python3.9 and other dependencies
RUN apt-get update -y && \
@@ -33,14 +34,18 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \
pip install -e . && \
pip install flashbax==0.1.0

ENV SC2PATH /home/app/StarCraftII
# ENV SC2PATH /home/app/StarCraftII
# RUN ./install_environments/smacv1.sh
RUN ./install_environments/smacv2.sh
# RUN ./install_environments/smacv2.sh

# ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia
# ENV SUPPRESS_GR_PROMPT 1
# RUN ./install_environments/mamujoco.sh

RUN ./install_environments/pettingzoo.sh

# RUN ./install_environments/flatland.sh

# Copy all code
COPY ./examples ./examples
COPY ./baselines ./baselines
4 changes: 2 additions & 2 deletions README.md
@@ -104,12 +104,12 @@ We are in the process of migrating our datasets from TF Records to Flashbax Vault
| 💣SMAC v2 | terran_5_vs_5 <br/> zerg_5_vs_5 <br/> terran_10_vs_10 | 5 <br/> 5 <br/> 10 | Discrete | Vector | Dense | Heterog | [source](https://github.com/oxwhirl/smacv2) |
| 🚅Flatland | 3 Trains <br/> 5 Trains | 3 <br/> 5 | Discrete | Vector | Sparse | Homog | [source](https://flatland.aicrowd.com/intro.html) |
| 🐜MAMuJoCo | 2-HalfCheetah <br/> 2-Ant <br/> 4-Ant | 2 <br/> 2 <br/> 4 | Cont. | Vector | Dense | Heterog <br/> Homog <br/> Homog | [source](https://github.com/schroederdewitt/multiagent_mujoco) |
| 🐻PettingZoo | Pursuit <br/> Co-op Pong | 8 <br/> 2 | Discrete <br/> Discrete | Pixels <br/> Pixels | Dense | Homog <br/> Heterog | [source](https://pettingzoo.farama.org/) |

### Legacy Datasets (still to be migrated to Vault) 👴
| Environment | Scenario | Agents | Act | Obs | Reward | Types | Repo |
|-----|----|----|-----|-----|----|----|-----|
| 🐻PettingZoo | Pursuit <br/> Co-op Pong <br/> PistonBall <br/> KAZ| 8 <br/> 2 <br/> 15 <br/> 2| Discrete <br/> Discrete <br/> Cont. <br/> Discrete | Pixels <br/> Pixels <br/> Pixels <br/> Vector | Dense | Homog <br/> Heterog <br/> Homog <br/> Heterog| [source](https://pettingzoo.farama.org/) |
| 🐻PettingZoo | PistonBall <br/> KAZ| 15 <br/> 2| Cont. <br/> Discrete | Pixels <br/> Vector | Dense | Homog <br/> Heterog| [source](https://pettingzoo.farama.org/) |
| 🏙️CityLearn | 2022_all_phases | 17 | Cont. | Vector | Dense | Homog | [source](https://github.com/intelligent-environments-lab/CityLearn) |
| 🔌Voltage Control | case33_3min_final | 6 | Cont. | Vector | Dense | Homog | [source](https://github.com/Future-Power-Networks/MAPDN) |
| 🔴MPE | simple_adversary | 3 | Discrete. | Vector | Dense | Competitive | [source](https://pettingzoo.farama.org/environments/mpe/simple_adversary/) |
14 changes: 9 additions & 5 deletions baselines/main.py
@@ -13,20 +13,21 @@
# limitations under the License.
from absl import app, flags

from og_marl.environments.utils import get_environment
from og_marl.environments import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.offline_dataset import download_and_unzip_vault
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "smac_v1", "Environment name.")
flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
flags.DEFINE_string("env", "pettingzoo", "Environment name.")
flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type: 'Good', 'Medium', 'Poor' or 'Replay'.")
flags.DEFINE_string("system", "dbc", "System name.")
flags.DEFINE_string("system", "qmix", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
flags.DEFINE_integer("batch_size", 64, "Batch size.")
@@ -43,7 +44,7 @@ def main(_):

env = get_environment(FLAGS.env, FLAGS.scenario)

buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=2)
buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)

download_and_unzip_vault(FLAGS.env, FLAGS.scenario)

@@ -65,6 +66,9 @@
)

system_kwargs = {"add_agent_id_to_obs": True}
if FLAGS.scenario == "pursuit":
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(FLAGS.system, env, logger, **system_kwargs)

system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)
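The `system_kwargs = {"add_agent_id_to_obs": True}` option in `main.py` suggests each agent's observation is augmented with a one-hot agent ID before being fed to the shared network. A minimal NumPy sketch of that idea (the helper name and shapes here are illustrative assumptions, not the library's actual implementation):

```python
import numpy as np

def add_agent_id_to_obs(obs: np.ndarray, agent_idx: int, num_agents: int) -> np.ndarray:
    """Append a one-hot agent ID to a flat observation vector (illustrative sketch)."""
    one_hot = np.zeros(num_agents, dtype=obs.dtype)
    one_hot[agent_idx] = 1.0
    return np.concatenate([obs, one_hot], axis=-1)

obs = np.ones(4, dtype=np.float32)
augmented = add_agent_id_to_obs(obs, agent_idx=1, num_agents=3)
print(augmented)  # [1. 1. 1. 1. 0. 1. 0.]
```

Concatenating the ID lets parameter-shared networks condition on which agent they are acting for, which matters for heterogeneous teams like Co-op Pong.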
45 changes: 33 additions & 12 deletions examples/tf2/run_all_baselines.py
@@ -1,32 +1,45 @@
import os

# import module
import traceback

from og_marl.environments import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.loggers import TerminalLogger, JsonWriter
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory

set_growing_gpu_memory()

os.environ["SUPPRESS_GR_PROMPT"] = 1  # bug: os.environ values must be strings
# For MAMuJoCo
os.environ["SUPPRESS_GR_PROMPT"] = "1"

scenario_system_configs = {
"smac_v1": {
"3m": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq"],
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 3000,
"trainer_steps": 2000,
"evaluate_every": 1000,
},
},
"mamujoco": {
"2halfcheetah": {
"systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
"pettingzoo": {
"pursuit": {
"systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
"trainer_steps": 3000,
"trainer_steps": 2000,
"evaluate_every": 1000,
},
},
# "mamujoco": {
# "2halfcheetah": {
# "systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
# "datasets": ["Good"],
# "trainer_steps": 3000,
# "evaluate_every": 1000,
# },
# },
}

seeds = [42]
@@ -44,7 +57,7 @@
"system": env_name,
"seed": seed,
}
logger = WandbLogger(config, project="og-marl-baselines")
logger = TerminalLogger()
env = get_environment(env_name, scenario_name)

buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)
@@ -55,10 +68,18 @@
raise ValueError("Vault not found. Exiting.")

json_writer = JsonWriter(
"logs", system_name, f"{scenario_name}_{dataset_name}", env_name, seed
"test_all_baselines",
system_name,
f"{scenario_name}_{dataset_name}",
env_name,
seed,
)

system_kwargs = {"add_agent_id_to_obs": True}

if scenario_name == "pursuit":
system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()

system = get_system(system_name, env, logger, **system_kwargs)

trainer_steps = scenario_system_configs[env_name][scenario_name][
@@ -75,7 +96,7 @@
)
except: # noqa: E722
logger.close()
print()
print("BROKEN")
print("BROKEN:", env_name, scenario_name, system_name)
traceback.print_exc()
print()
continue
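The `scenario_system_configs` dictionary above drives a nested sweep over environments, scenarios, datasets, systems, and seeds, with any failing combination reported and skipped rather than aborting the whole run. A simplified, self-contained sketch of that loop structure (the config values here are placeholders):

```python
import traceback

# Placeholder config mirroring the structure used in run_all_baselines.py
scenario_system_configs = {
    "smac_v1": {"3m": {"systems": ["idrqn", "qmix+cql"], "datasets": ["Good"]}},
    "pettingzoo": {"pursuit": {"systems": ["maicq"], "datasets": ["Good"]}},
}
seeds = [42]

runs = []
for env_name, scenarios in scenario_system_configs.items():
    for scenario_name, cfg in scenarios.items():
        for dataset_name in cfg["datasets"]:
            for system_name in cfg["systems"]:
                for seed in seeds:
                    try:
                        # system.train_offline(...) would run here
                        runs.append((env_name, scenario_name, system_name, dataset_name, seed))
                    except Exception:
                        # Report the broken combination and move on
                        print("BROKEN:", env_name, scenario_name, system_name)
                        traceback.print_exc()
                        continue

print(len(runs))  # 3
```

The `try/except ... continue` pattern is what lets one broken environment install (e.g. a missing vault) fail loudly without blocking the rest of the sweep.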
2 changes: 1 addition & 1 deletion install_environments/requirements/pettingzoo.txt
@@ -2,7 +2,7 @@ autorom
gym
numpy
opencv-python
pettingzoo==1.22.0
pettingzoo==1.23.1
pygame
pymunk
scipy
1 change: 0 additions & 1 deletion mkdocs.yml
@@ -38,7 +38,6 @@ theme:

nav:
- Home: 'index.md'
- Datasets: 'datasets.md'
- Baseline Results: 'baselines.md'
- Updates: 'updates.md'
- API Reference: 'api.md'
10 changes: 9 additions & 1 deletion og_marl/environments/__init__.py
@@ -18,7 +18,7 @@
from og_marl.environments.base import BaseEnvironment


def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
def get_environment(env_name: str, scenario: str) -> BaseEnvironment: # noqa: C901
if env_name == "smac_v1":
from og_marl.environments.smacv1 import SMACv1

@@ -31,6 +31,14 @@ def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
from og_marl.environments.old_mamujoco import MAMuJoCo

return MAMuJoCo(scenario)
elif scenario == "pursuit":
from og_marl.environments.pursuit import Pursuit

return Pursuit()
elif scenario == "coop_pong":
from og_marl.environments.coop_pong import CooperativePong

return CooperativePong()
elif env_name == "gymnasium_mamujoco":
from og_marl.environments.gymnasium_mamujoco import MAMuJoCo

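`get_environment` dispatches on the environment name and imports each wrapper lazily inside its branch, so optional dependencies (StarCraft II, MuJoCo, PettingZoo) are only loaded when that environment is actually requested. A generic sketch of the same lazy-import pattern, using stdlib modules so it runs anywhere (the registry contents are illustrative):

```python
import importlib

# name -> (module path, attribute); the module is only imported on first request
_REGISTRY = {
    "json_dumps": ("json", "dumps"),
    "csv_reader": ("csv", "reader"),
}

def get_component(name: str):
    """Resolve a name to an object, importing its module lazily."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown component: {name}")
    module_path, attr = _REGISTRY[name]
    module = importlib.import_module(module_path)
    return getattr(module, attr)

dumps = get_component("json_dumps")
print(dumps({"a": 1}))  # {"a": 1}
```

Keeping imports inside the branches (or behind a registry like this) means a user without, say, StarCraft II installed can still use the PettingZoo environments.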
114 changes: 114 additions & 0 deletions og_marl/environments/coop_pong.py
@@ -0,0 +1,114 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrapper for Cooperative Pettingzoo environments."""
from typing import Any, List, Dict

import numpy as np
from pettingzoo.butterfly import cooperative_pong_v5
import supersuit

from og_marl.environments.base import BaseEnvironment
from og_marl.environments.base import Observations, ResetReturn, StepReturn


class CooperativePong(BaseEnvironment):
"""Environment wrapper PettingZoo Cooperative Pong."""

def __init__(
self,
) -> None:
"""Constructor."""
self._environment = cooperative_pong_v5.parallel_env(render_mode="rgb_array")
# Wrap environment with supersuit pre-process wrappers
self._environment = supersuit.color_reduction_v0(self._environment, mode="R")
self._environment = supersuit.resize_v0(self._environment, x_size=145, y_size=84)
self._environment = supersuit.dtype_v0(self._environment, dtype="float32")
self._environment = supersuit.normalize_obs_v0(self._environment)

self._agents = self._environment.possible_agents
self._done = False
self.max_episode_length = 900

def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
observations, _ = self._environment.reset() # type: ignore

# Convert observations
observations = self._convert_observations(observations)

# Global state
env_state = self._create_state_representation(observations, first=True)

# Infos
info = {"state": env_state, "legals": self._legals}

return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps in env."""
# Step the environment
observations, rewards, terminals, truncations, _ = self._environment.step(actions)

# Convert observations
observations = self._convert_observations(observations)

# Global state
env_state = self._create_state_representation(observations)

# Extra infos
info = {"state": env_state, "legals": self._legals}

return observations, rewards, terminals, truncations, info

def _create_state_representation(self, observations: Observations, first: bool = False) -> Any:
if first:
self._state_history = np.zeros((84, 145, 4), "float32")

state = np.expand_dims(observations["paddle_0"][:, :], axis=-1)

# framestacking
self._state_history = np.concatenate((state, self._state_history[:, :, :3]), axis=-1)

return self._state_history

def _convert_observations(self, observations: List) -> Observations:
"""Make observations partial."""
processed_observations = {}
for agent in self._agents:
if agent == "paddle_0":
agent_obs = observations[agent][:, :110] # hide the other agent
else:
agent_obs = observations[agent][:, 35:] # hide the other agent

agent_obs = np.expand_dims(agent_obs, axis=-1)
processed_observations[agent] = agent_obs

return processed_observations

def __getattr__(self, name: str) -> Any:
"""Expose any other attributes of the underlying environment.

Args:
name (str): attribute.

Returns:
Any: return attribute from env or underlying env.
"""
if hasattr(self.__class__, name):
return self.__getattribute__(name)
else:
return getattr(self._environment, name)
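The `_create_state_representation` method above maintains a rolling stack of the last four grayscale frames as the global state, with the newest frame in channel 0. The core framestacking update can be isolated as a small NumPy sketch (shapes match the wrapper's 84x145 resized frames):

```python
import numpy as np

def stack_frame(history: np.ndarray, frame: np.ndarray) -> np.ndarray:
    """Prepend the newest frame and drop the oldest along the channel axis."""
    frame = np.expand_dims(frame, axis=-1)                       # (H, W) -> (H, W, 1)
    return np.concatenate((frame, history[:, :, :3]), axis=-1)   # keep 4 newest

history = np.zeros((84, 145, 4), dtype="float32")
for t in range(1, 6):
    history = stack_frame(history, np.full((84, 145), float(t), dtype="float32"))

print(history[0, 0])  # [5. 4. 3. 2.] -- newest frame first
```

Stacking frames gives the value networks short-term motion information (e.g. the ball's velocity in Pong) that a single grayscale frame cannot convey.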
2 changes: 1 addition & 1 deletion og_marl/environments/pettingzoo_base.py
@@ -30,7 +30,7 @@ def __init__(self) -> None:
def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
observations = self._environment.reset() # type: ignore
observations, _ = self._environment.reset() # type: ignore

# Global state
env_state = self._create_state_representation(observations)