
Commit

Merge branch 'main' into feat/mkdocstrings
callumtilbury committed Mar 11, 2024
2 parents c9ffdad + bfeb79e commit 1fabc48
Showing 13 changed files with 455 additions and 124 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ vaults_unprocessed
development
SMAC_Maps
logs
development
__MACOSX
3.9

31 changes: 8 additions & 23 deletions README.md
@@ -59,19 +59,24 @@ Download environment dependencies. We will use SMACv1 in this example.

Download a dataset.

`python examples/download_vault.py --env=smac_v1 --scenario=3m`
`python examples/download_dataset.py --env=smac_v1 --scenario=3m`

Run a baseline. In this example we will run MAICQ.

`python baselines/main.py --env=smac_v1 --scenario=3m --dataset=Good --system=maicq`

## Dataset API

We provide a simple notebook demonstrating how to use OG-MARL's dataset API here:

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/og-marl/blob/main/examples/dataset_api_demo.ipynb)

## Datasets 🎥

We have generated datasets on a diverse set of popular MARL environments. A list of currently supported environments is included in the table below. It is well known from the single-agent offline RL literature that the quality of experience in an offline dataset plays a large role in the final performance of offline RL algorithms. In OG-MARL we therefore include a range of dataset distributions for each environment and scenario, including `Good`, `Medium`, `Poor` and `Replay` datasets, so that offline MARL algorithms can be benchmarked across different dataset qualities. For more information on why we chose to include each environment and on its task properties, please read our accompanying [paper](https://arxiv.org/abs/2302.00521).

<div class="collage">
<div class="row" align="center">
<!-- <img src="docs/assets/smac.png" alt="SMAC v1" width="16%"/> -->
<img src="docs/assets/smacv2.png" alt="SMAC v2" width="16%"/>
<img src="docs/assets/pistonball.png" alt="Pistonball" width="16%"/>
<img src="docs/assets/coop_pong.png" alt="Cooperative Pong" width="16%"/>
@@ -109,15 +114,8 @@ We are in the process of migrating our datasets from TF Records to Flashbax Vaults.
| 🔌Voltage Control | case33_3min_final | 6 | Cont. | Vector | Dense | Homog | [source](https://github.com/Future-Power-Networks/MAPDN) |
| 🔴MPE | simple_adversary | 3 | Discrete | Vector | Dense | Competitive | [source](https://pettingzoo.farama.org/environments/mpe/simple_adversary/) |

## Dataset API

We provide a simple demonstrative notebook of how to use OG-MARL's dataset API here:

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/og-marl/blob/main/examples/dataset_api_demo.ipynb)


### Dataset and Vault Locations
For OG-MARL's systems, we require the following dataset storage structure:
For OG-MARL's systems, we require the following dataset file structure:

```
examples/
@@ -137,19 +135,6 @@ vaults/
| |_> Medium/
| |_> Poor/
|_> ...
datasets/
|_> smac_v1/
|_> 3m/
| |_> Good/
| |_> Medium/
| |_> Poor/
|_> ...
|_> smac_v2/
|_> terran_5_vs_5/
| |_> Good/
| |_> Medium/
| |_> Poor/
|_> ...
...
```
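As a quick illustration of how these locations are consumed, the sketch below reads one vault with Flashbax. The exact arguments (the `.vlt` suffix on the vault name and `Good` as the `vault_uid`) are assumptions about the on-disk naming; the demo notebook linked above shows the canonical usage.

```python
# A minimal sketch of reading a vault, assuming Flashbax's Vault API.
# The vault_name suffix (".vlt") and vault_uid ("Good") are assumptions
# about the on-disk naming; see examples/dataset_api_demo.ipynb.
from flashbax.vault import Vault

vault = Vault(rel_dir="vaults/smac_v1", vault_name="3m.vlt", vault_uid="Good")
experience = vault.read().experience  # pytree of stored trajectories
```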

26 changes: 13 additions & 13 deletions examples/download_dataset.py
@@ -11,21 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import app, flags

from og_marl.environments import smacv1
from og_marl.offline_dataset import OfflineMARLDataset, download_and_unzip_dataset
from og_marl.offline_dataset import download_and_unzip_vault

# Comment this out if you already downloaded the dataset
download_and_unzip_dataset("smac_v1", "3m", dataset_base_dir="datasets")
FLAGS = flags.FLAGS
flags.DEFINE_string("env", "smac_v1", "Environment name.")
flags.DEFINE_string("scenario", "3m", "Environment scenario name.")

# Compute mean episode return of Good dataset

env = smacv1.SMACv1("3m") # Change SMAC Scenario Here
dataset = OfflineMARLDataset(env, "datasets/smac_v1/3m/Good")
def main(_):
# Download vault
    download_and_unzip_vault(FLAGS.env, FLAGS.scenario)

sample_cnt = 0
tot_returns = 0
for sample in dataset._tf_dataset:
sample_cnt += 1
tot_returns += sample["episode_return"].numpy()
print("Mean Episode return:", tot_returns / sample_cnt)
# NEXT STEPS: See `examples/dataset_api_demo.ipynb`


if __name__ == "__main__":
app.run(main)
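With the flags defined above, this script is invoked exactly as in the quickstart: `python examples/download_dataset.py --env=smac_v1 --scenario=3m`.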
28 changes: 28 additions & 0 deletions examples/tf2/online/idrqn_smax.py
@@ -0,0 +1,28 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from og_marl.environments.jaxmarl_smax import SMAX
from og_marl.loggers import WandbLogger
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.systems.qmix import QMIXSystem

env = SMAX("3m")

logger = WandbLogger(entity="claude_formanek")

system = QMIXSystem(env, logger, eps_decay_timesteps=50_000)

replay_buffer = FlashbaxReplayBuffer(sequence_length=20)

system.train_online(replay_buffer)
2 changes: 1 addition & 1 deletion examples/tf2/run_all_baselines.py
@@ -1,6 +1,6 @@
import os

from og_marl.environments.utils import get_environment
from og_marl.environments import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.systems import get_system
33 changes: 0 additions & 33 deletions manifest.yaml

This file was deleted.

59 changes: 59 additions & 0 deletions og_marl/environments/__init__.py
@@ -0,0 +1,59 @@
# type: ignore

# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from og_marl.environments.base import BaseEnvironment


def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
if env_name == "smac_v1":
from og_marl.environments.smacv1 import SMACv1

return SMACv1(scenario)
elif env_name == "smac_v2":
from og_marl.environments.smacv2 import SMACv2

return SMACv2(scenario)
elif env_name == "mamujoco":
from og_marl.environments.old_mamujoco import MAMuJoCo

return MAMuJoCo(scenario)
elif env_name == "gymnasium_mamujoco":
from og_marl.environments.gymnasium_mamujoco import MAMuJoCo

return MAMuJoCo(scenario)
elif env_name == "flatland":
from og_marl.environments.flatland_wrapper import Flatland

return Flatland(scenario)
elif env_name == "voltage_control":
from og_marl.environments.voltage_control import VoltageControlEnv

return VoltageControlEnv()
elif env_name == "smax":
from og_marl.environments.jaxmarl_smax import SMAX

return SMAX(scenario)
elif env_name == "lbf":
from og_marl.environments.jumanji_lbf import JumanjiLBF

return JumanjiLBF(scenario)
elif env_name == "rware":
from og_marl.environments.jumanji_rware import JumanjiRware

return JumanjiRware(scenario)
else:
        raise ValueError(f"Environment {env_name} not recognised.")
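As a quick usage sketch for the factory above (the reset interface shown is taken from the SMAX wrapper added in this same commit, and is assumed to be shared by the other wrappers):

```python
# Hypothetical usage of the get_environment factory defined above.
from og_marl.environments import get_environment

env = get_environment("smac_v1", "3m")  # any name handled by the factory
observations, info = env.reset()  # interface assumed from the SMAX wrapper below
```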
98 changes: 98 additions & 0 deletions og_marl/environments/jaxmarl_smax.py
@@ -0,0 +1,98 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base wrapper for Jumanji environments."""
from typing import Any, Dict

import jax
import numpy as np
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn


class SMAX(BaseEnvironment):
    """Environment wrapper for JaxMARL's SMAX environment."""

def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None:
"""Constructor."""
scenario = map_name_to_scenario(scenario_name)

self._environment = make(
"HeuristicEnemySMAX",
enemy_shoots=True,
scenario=scenario,
use_self_play_reward=False,
walls_cause_death=True,
see_enemy_actions=False,
)

self._num_agents = self._environment.num_agents
self.possible_agents = self._environment.agents
self._num_actions = int(self._environment.action_spaces[self.possible_agents[0]].n)

self._state = ... # Jaxmarl environment state

self.info_spec: Dict[str, Any] = {} # TODO add global state spec

self._key = jax.random.PRNGKey(seed)

self._env_step = jax.jit(self._environment.step)

def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
self._key, sub_key = jax.random.split(self._key)
obs, self._state = self._environment.reset(sub_key)

observations = {
agent: np.asarray(obs[agent], dtype=np.float32) for agent in self.possible_agents
}
legals = {
agent: np.array(legal, "int64")
for agent, legal in self._environment.get_avail_actions(self._state).items()
}
state = np.asarray(obs["world_state"], "float32")

# Infos
info = {"legals": legals, "state": state}

return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps in env."""
self._key, sub_key = jax.random.split(self._key)

# Step the environment
        # Use the step function jitted in the constructor
        obs, self._state, reward, done, infos = self._env_step(
            sub_key, self._state, actions
        )

observations = {
agent: np.asarray(obs[agent], dtype=np.float32) for agent in self.possible_agents
}
legals = {
agent: np.array(legal, "int64")
for agent, legal in self._environment.get_avail_actions(self._state).items()
}
state = np.asarray(obs["world_state"], "float32")

# Infos
info = {"legals": legals, "state": state}

rewards = {agent: reward[agent] for agent in self.possible_agents}
terminals = {agent: done["__all__"] for agent in self.possible_agents}
truncations = {agent: False for agent in self.possible_agents}

return observations, rewards, terminals, truncations, info
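To show the wrapper's interface end to end, here is a minimal random-rollout sketch. It relies only on the `reset`/`step` signatures defined above; sampling uniformly from the `legals` mask is an assumption about how discrete actions might be chosen, not part of the wrapper itself.

```python
# Minimal random-rollout sketch against the SMAX wrapper above.
# Uniform sampling from the "legals" mask is an assumption, not og-marl API.
import numpy as np

env = SMAX("3m")
observations, info = env.reset()
done = False
while not done:
    # Choose a random legal action for each agent.
    actions = {
        agent: np.random.choice(np.nonzero(info["legals"][agent])[0])
        for agent in env.possible_agents
    }
    observations, rewards, terminals, truncations, info = env.step(actions)
    done = all(terminals.values()) or all(truncations.values())
```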