diff --git a/Dockerfile b/Dockerfile index efcfa61b..4b44b989 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 # Ensure no installs try to launch interactive screen ARG DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 # Update packages and install python3.9 and other dependencies RUN apt-get update -y && \ @@ -33,14 +34,18 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \ pip install -e . && \ pip install flashbax==0.1.0 -ENV SC2PATH /home/app/StarCraftII +# ENV SC2PATH /home/app/StarCraftII # RUN ./install_environments/smacv1.sh -RUN ./install_environments/smacv2.sh +# RUN ./install_environments/smacv2.sh # ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia # ENV SUPPRESS_GR_PROMPT 1 # RUN ./install_environments/mamujoco.sh +RUN ./install_environments/pettingzoo.sh + +# RUN ./install_environments/flatland.sh + # Copy all code COPY ./examples ./examples COPY ./baselines ./baselines diff --git a/README.md b/README.md index 92ee6825..201306d7 100644 --- a/README.md +++ b/README.md @@ -104,12 +104,12 @@ We are in the process of migrating our datasets from TF Records to Flashbax Vaul | 💣SMAC v2 | terran_5_vs_5
zerg_5_vs_5
terran_10_vs_10 | 5
5
10 | Discrete | Vector | Dense | Heterog | [source](https://github.com/oxwhirl/smacv2) | | 🚅Flatland | 3 Trains
5 Trains | 3
5 | Discrete | Vector | Sparse | Homog | [source](https://flatland.aicrowd.com/intro.html) | | 🐜MAMuJoCo | 2-HalfCheetah
2-Ant
4-Ant | 2
2
4 | Cont. | Vector | Dense | Heterog
Homog
Homog | [source](https://github.com/schroederdewitt/multiagent_mujoco) | - +| 🐻PettingZoo | Pursuit
Co-op Pong | 8
2 | Discrete
Discrete | Pixels
Pixels | Dense | Homog
Heterog | [source](https://pettingzoo.farama.org/) | ### Legacy Datasets (still to be migrated to Vault) 👴 | Environment | Scenario | Agents | Act | Obs | Reward | Types | Repo | |-----|----|----|-----|-----|----|----|-----| -| 🐻PettingZoo | Pursuit
Co-op Pong
PistonBall
KAZ| 8
2
15
2| Discrete
Discrete
Cont.
Discrete | Pixels
Pixels
Pixels
Vector | Dense | Homog
Heterog
Homog
Heterog| [source](https://pettingzoo.farama.org/) | +| 🐻PettingZoo | PistonBall
KAZ| 15
2| Cont.
Discrete | Pixels
Vector | Dense | Homog
Heterog| [source](https://pettingzoo.farama.org/) | | 🏙️CityLearn | 2022_all_phases | 17 | Cont. | Vector | Dense | Homog | [source](https://github.com/intelligent-environments-lab/CityLearn) | | 🔌Voltage Control | case33_3min_final | 6 | Cont. | Vector | Dense | Homog | [source](https://github.com/Future-Power-Networks/MAPDN) | | 🔴MPE | simple_adversary | 3 | Discrete. | Vector | Dense | Competitive | [source](https://pettingzoo.farama.org/environments/mpe/simple_adversary/) | diff --git a/baselines/main.py b/baselines/main.py index 1fdb8be0..0f439fc4 100644 --- a/baselines/main.py +++ b/baselines/main.py @@ -13,20 +13,21 @@ # limitations under the License. from absl import app, flags -from og_marl.environments.utils import get_environment +from og_marl.environments import get_environment from og_marl.loggers import JsonWriter, WandbLogger from og_marl.offline_dataset import download_and_unzip_vault from og_marl.replay_buffers import FlashbaxReplayBuffer +from og_marl.tf2.networks import CNNEmbeddingNetwork from og_marl.tf2.systems import get_system from og_marl.tf2.utils import set_growing_gpu_memory set_growing_gpu_memory() FLAGS = flags.FLAGS -flags.DEFINE_string("env", "smac_v1", "Environment name.") -flags.DEFINE_string("scenario", "3m", "Environment scenario name.") +flags.DEFINE_string("env", "pettingzoo", "Environment name.") +flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.") flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ") -flags.DEFINE_string("system", "dbc", "System name.") +flags.DEFINE_string("system", "qmix", "System name.") flags.DEFINE_integer("seed", 42, "Seed.") flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.") flags.DEFINE_integer("batch_size", 64, "Number of training steps.") @@ -43,7 +44,7 @@ def main(_): env = get_environment(FLAGS.env, FLAGS.scenario) - buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=2) + buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1) download_and_unzip_vault(FLAGS.env, FLAGS.scenario) @@ -65,6 +66,9 @@ def main(_): ) system_kwargs = {"add_agent_id_to_obs": True} + if FLAGS.scenario == "pursuit": + system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork() + system = get_system(FLAGS.system, env, logger, **system_kwargs) system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer) diff --git a/examples/tf2/run_all_baselines.py b/examples/tf2/run_all_baselines.py index 51caece7..957a2b71 100644 --- a/examples/tf2/run_all_baselines.py +++ b/examples/tf2/run_all_baselines.py @@ -1,32 +1,45 @@ import os +# import module +import traceback + from og_marl.environments import get_environment -from og_marl.loggers import JsonWriter, WandbLogger +from og_marl.loggers import TerminalLogger, JsonWriter from og_marl.replay_buffers import FlashbaxReplayBuffer +from og_marl.tf2.networks import CNNEmbeddingNetwork from og_marl.tf2.systems import get_system from og_marl.tf2.utils import set_growing_gpu_memory set_growing_gpu_memory() -os.environ["SUPPRESS_GR_PROMPT"] = 1 +# For MAMuJoCo +os.environ["SUPPRESS_GR_PROMPT"] = "1" scenario_system_configs = { "smac_v1": { "3m": { - "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq"], + "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"], "datasets": ["Good"], - "trainer_steps": 3000, + "trainer_steps": 2000, "evaluate_every": 1000, }, }, - "mamujoco": { - "2halfcheetah": { - "systems": 
["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"], + "pettingzoo": { + "pursuit": { + "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"], "datasets": ["Good"], - "trainer_steps": 3000, + "trainer_steps": 2000, "evaluate_every": 1000, }, }, + # "mamujoco": { + # "2halfcheetah": { + # "systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"], + # "datasets": ["Good"], + # "trainer_steps": 3000, + # "evaluate_every": 1000, + # }, + # }, } seeds = [42] @@ -44,7 +57,7 @@ "system": env_name, "seed": seed, } - logger = WandbLogger(config, project="og-marl-baselines") + logger = TerminalLogger() env = get_environment(env_name, scenario_name) buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1) @@ -55,10 +68,18 @@ raise ValueError("Vault not found. Exiting.") json_writer = JsonWriter( - "logs", system_name, f"{scenario_name}_{dataset_name}", env_name, seed + "test_all_baselines", + system_name, + f"{scenario_name}_{dataset_name}", + env_name, + seed, ) system_kwargs = {"add_agent_id_to_obs": True} + + if scenario_name == "pursuit": + system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork() + system = get_system(system_name, env, logger, **system_kwargs) trainer_steps = scenario_system_configs[env_name][scenario_name][ @@ -75,7 +96,7 @@ ) except: # noqa: E722 logger.close() - print() - print("BROKEN") + print("BROKEN:", env_name, scenario_name, system_name) + traceback.print_exc() print() continue diff --git a/install_environments/requirements/pettingzoo.txt b/install_environments/requirements/pettingzoo.txt index a0dc1340..151bd05e 100755 --- a/install_environments/requirements/pettingzoo.txt +++ b/install_environments/requirements/pettingzoo.txt @@ -2,7 +2,7 @@ autorom gym numpy opencv-python -pettingzoo==1.22.0 +pettingzoo==1.23.1 pygame pymunk scipy diff --git a/mkdocs.yml b/mkdocs.yml index a8bccadc..75ace30d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,7 +38,6 @@ theme: nav: - Home: 'index.md' - - Datasets: 'datasets.md' - Baseline Results: 'baselines.md' - Updates: 'updates.md' - API Reference: 'api.md' diff --git a/og_marl/environments/__init__.py b/og_marl/environments/__init__.py index c1439ff2..846c161c 100644 --- a/og_marl/environments/__init__.py +++ b/og_marl/environments/__init__.py @@ -18,7 +18,7 @@ from og_marl.environments.base import BaseEnvironment -def get_environment(env_name: str, scenario: str) -> BaseEnvironment: +def get_environment(env_name: str, scenario: str) -> BaseEnvironment: # noqa: C901 if env_name == "smac_v1": from og_marl.environments.smacv1 import SMACv1 @@ -31,6 +31,14 @@ def get_environment(env_name: str, scenario: str) -> BaseEnvironment: from og_marl.environments.old_mamujoco import MAMuJoCo return MAMuJoCo(scenario) + elif scenario == "pursuit": + from og_marl.environments.pursuit import Pursuit + + return Pursuit() + elif scenario == "coop_pong": + from og_marl.environments.coop_pong import CooperativePong + + return CooperativePong() elif env_name == "gymnasium_mamujoco": from og_marl.environments.gymnasium_mamujoco import MAMuJoCo diff --git a/og_marl/environments/coop_pong.py b/og_marl/environments/coop_pong.py new file mode 100644 index 00000000..0c259f26 --- /dev/null +++ b/og_marl/environments/coop_pong.py @@ -0,0 +1,114 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper for Cooperative Pettingzoo environments.""" +from typing import Any, List, Dict + +import numpy as np +from pettingzoo.butterfly import cooperative_pong_v5 +import supersuit + +from og_marl.environments.base import BaseEnvironment +from og_marl.environments.base import Observations, ResetReturn, StepReturn + + +class CooperativePong(BaseEnvironment): + """Environment wrapper PettingZoo Cooperative Pong.""" + + def __init__( + self, + ) -> None: + """Constructor.""" + self._environment = cooperative_pong_v5.parallel_env(render_mode="rgb_array") + # Wrap environment with supersuit pre-process wrappers + self._environment = supersuit.color_reduction_v0(self._environment, mode="R") + self._environment = supersuit.resize_v0(self._environment, x_size=145, y_size=84) + self._environment = supersuit.dtype_v0(self._environment, dtype="float32") + self._environment = supersuit.normalize_obs_v0(self._environment) + + self._agents = self._environment.possible_agents + self._done = False + self.max_episode_length = 900 + + def reset(self) -> ResetReturn: + """Resets the env.""" + # Reset the environment + observations, _ = self._environment.reset() # type: ignore + + # Convert observations + observations = self._convert_observations(observations) + + # Global state + env_state = self._create_state_representation(observations, first=True) + + # Infos + info = {"state": env_state, "legals": self._legals} + + return observations, info + + def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps in env.""" + # Step the environment + observations, rewards, terminals, truncations, _ = self._environment.step(actions) + + # Convert observations + observations = self._convert_observations(observations) + + # Global state + env_state = self._create_state_representation(observations) + + # Extra infos + info = {"state": env_state, "legals": self._legals} + + return observations, rewards, terminals, truncations, info + + def _create_state_representation(self, observations: Observations, first: bool = False) -> Any: + if first: + self._state_history = np.zeros((84, 145, 4), "float32") + + state = np.expand_dims(observations["paddle_0"][:, :], axis=-1) + + # framestacking + self._state_history = np.concatenate((state, self._state_history[:, :, :3]), axis=-1) + + return self._state_history + + def _convert_observations(self, observations: List) -> Observations: + """Make observations partial.""" + processed_observations = {} + for agent in self._agents: + if agent == "paddle_0": + agent_obs = observations[agent][:, :110] # hide the other agent + else: + agent_obs = observations[agent][:, 35:] # hide the other agent + + agent_obs = np.expand_dims(agent_obs, axis=-1) + processed_observations[agent] = agent_obs + + return processed_observations + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment. + + Args: + name (str): attribute. + + Returns: + Any: return attribute from env or underlying env. 
+ """ + if hasattr(self.__class__, name): + return self.__getattribute__(name) + else: + return getattr(self._environment, name) diff --git a/og_marl/environments/pettingzoo_base.py b/og_marl/environments/pettingzoo_base.py index edbfd13d..bbbef7c4 100644 --- a/og_marl/environments/pettingzoo_base.py +++ b/og_marl/environments/pettingzoo_base.py @@ -30,7 +30,7 @@ def __init__(self) -> None: def reset(self) -> ResetReturn: """Resets the env.""" # Reset the environment - observations = self._environment.reset() # type: ignore + observations, _ = self._environment.reset() # type: ignore # Global state env_state = self._create_state_representation(observations) diff --git a/og_marl/environments/pursuit.py b/og_marl/environments/pursuit.py index e6bb028e..9050cb41 100644 --- a/og_marl/environments/pursuit.py +++ b/og_marl/environments/pursuit.py @@ -11,24 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Dict import numpy as np from gymnasium.spaces import Box, Discrete from pettingzoo.sisl import pursuit_v4 -from supersuit import black_death_v3 -from og_marl.environments.base import Observations -from og_marl.environments.pettingzoo_base import PettingZooBase +from og_marl.environments.base import BaseEnvironment, Observations, ResetReturn, StepReturn -class Pursuit(PettingZooBase): +class Pursuit(BaseEnvironment): """Environment wrapper for Pursuit.""" def __init__(self) -> None: """Constructor for Pursuit""" - self._environment = black_death_v3(pursuit_v4.parallel_env()) + self._environment = pursuit_v4.parallel_env() self.possible_agents = self._environment.possible_agents self._num_actions = 5 self._obs_dim = (7, 7, 3) @@ -38,11 +36,39 @@ def __init__(self) -> None: agent: Box(-np.inf, np.inf, (*self._obs_dim,)) for agent in self.possible_agents } - self.info_spec = {"state": np.zeros(8 * 2 + 30 * 2, "float32")} + self._legals = { + agent: np.ones((self._num_actions,), "int32") for agent in self.possible_agents + } + + self.info_spec = {"state": np.zeros(8 * 2 + 30 * 2, "float32"), "legals": self._legals} + + self.max_episode_length = 500 + + def reset(self) -> ResetReturn: + """Resets the env.""" + # Reset the environment + observations, _ = self._environment.reset() # type: ignore + + # Global state + env_state = self._create_state_representation(observations) + + # Infos + info = {"state": env_state, "legals": self._legals} + + return observations, info + + def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps in env.""" + # Step the environment + observations, rewards, terminals, truncations, _ = self._environment.step(actions) + + # Global state + env_state = self._create_state_representation(observations) + + # Extra infos + info = {"state": env_state, "legals": self._legals} - def _convert_observations(self, observations: Observations, done: bool) -> Observations: - """Convert observations.""" - return observations + return observations, rewards, terminals, truncations, info def _create_state_representation(self, observations: Observations) -> np.ndarray: pursuer_pos = [ diff --git a/og_marl/offline_dataset.py b/og_marl/offline_dataset.py index 300734ee..d209a07c 100644 --- a/og_marl/offline_dataset.py +++ b/og_marl/offline_dataset.py @@ -15,22 +15,17 @@ import os import sys import zipfile -from pathlib import Path -from typing import Any, Dict, List +from typing import Dict, List import jax import 
jax.numpy as jnp import matplotlib.pyplot as plt import requests # type: ignore import seaborn as sns -import tensorflow as tf -import tree from chex import Array from flashbax.vault import Vault from git import Optional -from tensorflow import DType -from og_marl.environments.base import BaseEnvironment VAULT_INFO = { "smac_v1": { @@ -64,236 +59,15 @@ "4ant": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/4ant.zip"}, }, "flatland": { - "5trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/3trains.zip"}, - "2trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/5trains.zip"}, + "3trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/3trains.zip"}, + "5trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/5trains.zip"}, }, -} - -DATASET_INFO = { - "smac_v1": { - "3m": {"url": "https://tinyurl.com/3m-dataset", "sequence_length": 20, "period": 10}, - "8m": {"url": "https://tinyurl.com/8m-dataset", "sequence_length": 20, "period": 10}, - "5m_vs_6m": { - "url": "https://tinyurl.com/5m-vs-6m-dataset", - "sequence_length": 20, - "period": 10, - }, - "2s3z": {"url": "https://tinyurl.com/2s3z-dataset", "sequence_length": 20, "period": 10}, - "3s5z_vs_3s6z": { - "url": "https://tinyurl.com/3s5z-vs-3s6z-dataset3", - "sequence_length": 20, - "period": 10, - }, - "2c_vs_64zg": { - "url": "https://tinyurl.com/2c-vs-64zg-dataset", - "sequence_length": 20, - "period": 10, - }, - "27m_vs_30m": { - "url": "https://tinyurl.com/27m-vs-30m-dataset", - "sequence_length": 20, - "period": 10, - }, - }, - "smac_v2": { - "terran_5_vs_5": { - "url": "https://tinyurl.com/terran-5-vs-5-dataset", - "sequence_length": 20, - "period": 10, - }, - "zerg_5_vs_5": { - "url": "https://tinyurl.com/zerg-5-vs-5-dataset", - "sequence_length": 20, - "period": 10, - }, - "terran_10_vs_10": { - "url": "https://tinyurl.com/terran-10-vs-10-dataset", - "sequence_length": 20, - "period": 10, - }, - }, - "flatland": { - "3trains": { - "url": "https://tinyurl.com/3trains-dataset", - "sequence_length": 20, # TODO - "period": 10, - }, - "5trains": { - "url": "https://tinyurl.com/5trains-dataset", - "sequence_length": 20, # TODO - "period": 10, - }, - }, - "mamujoco": { - "2halfcheetah": { - "url": "https://tinyurl.com/2halfcheetah-dataset", - "sequence_length": 20, - "period": 10, - }, - "2ant": {"url": "https://tinyurl.com/2ant-dataset", "sequence_length": 20, "period": 10}, - "4ant": {"url": "https://tinyurl.com/4ant-dataset", "sequence_length": 20, "period": 10}, - }, - "voltage_control": { - "case33_3min_final": { - "url": "https://tinyurl.com/case33-3min-final-dataset", - "sequence_length": 20, - "period": 10, - }, + "pettingzoo": { + "pursuit": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/pursuit.zip"} }, } -def get_schema_dtypes(environment: BaseEnvironment) -> Dict[str, DType]: - act_type = list(environment.action_spaces.values())[0].dtype - schema = {} - for agent in environment.possible_agents: - schema[agent + "_observations"] = tf.float32 - schema[agent + "_legal_actions"] = tf.float32 - schema[agent + "_actions"] = act_type - schema[agent + "_rewards"] = tf.float32 - schema[agent + "_discounts"] = tf.float32 - - ## Extras - # Zero-padding mask - schema["zero_padding_mask"] = tf.float32 - - # Env state - schema["env_state"] = tf.float32 - - # Episode return - schema["episode_return"] = tf.float32 - - return schema - - -class OfflineMARLDataset: - def __init__( - self, - environment: BaseEnvironment, - env_name: 
str, - scenario_name: str, - dataset_type: str, - base_dataset_dir: str = "./datasets", - ): - self._environment = environment - self._schema = get_schema_dtypes(environment) - self._agents = environment.possible_agents - - path_to_dataset = f"{base_dataset_dir}/{env_name}/{scenario_name}/{dataset_type}" - - file_path = Path(path_to_dataset) - sub_dir_to_idx = {} - idx = 0 - for subdir in os.listdir(file_path): - if file_path.joinpath(subdir).is_dir(): - sub_dir_to_idx[subdir] = idx - idx += 1 - - def get_fname_idx(file_name: str) -> int: - dir_idx = sub_dir_to_idx[file_name.split("/")[-2]] * 1000 - return dir_idx + int(file_name.split("log_")[-1].split(".")[0]) - - filenames = [str(file_name) for file_name in file_path.glob("**/*.tfrecord")] - filenames = sorted(filenames, key=get_fname_idx) - - filename_dataset = tf.data.Dataset.from_tensor_slices(filenames) - self.raw_dataset = filename_dataset.flat_map( - lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP").map(self._decode_fn) - ) - - self.period = DATASET_INFO[env_name][scenario_name]["period"] - self.sequence_length = DATASET_INFO[env_name][scenario_name]["sequence_length"] - self.max_episode_length = environment.max_episode_length - - def _decode_fn(self, record_bytes: Any) -> Dict[str, Any]: - example = tf.io.parse_single_example( - record_bytes, - tree.map_structure(lambda x: tf.io.FixedLenFeature([], dtype=tf.string), self._schema), - ) - - for key, dtype in self._schema.items(): - example[key] = tf.io.parse_tensor(example[key], dtype) - - sample: Dict[str, dict] = { - "observations": {}, - "actions": {}, - "rewards": {}, - "terminals": {}, - "truncations": {}, - "infos": {"legals": {}}, - } - for agent in self._agents: - sample["observations"][agent] = example[f"{agent}_observations"] - sample["actions"][agent] = example[f"{agent}_actions"] - sample["rewards"][agent] = example[f"{agent}_rewards"] - sample["terminals"][agent] = 1 - example[f"{agent}_discounts"] - sample["truncations"][agent] = tf.zeros_like(example[f"{agent}_discounts"]) - sample["infos"]["legals"][agent] = example[f"{agent}_legal_actions"] - - sample["infos"]["mask"] = example["zero_padding_mask"] - sample["infos"]["state"] = example["env_state"] - sample["infos"]["episode_return"] = example["episode_return"] - - return sample - - def __getattr__(self, name: Any) -> Any: - """Expose any other attributes of the underlying environment. - - Args: - ---- - name (str): attribute. - - Returns: - ------- - Any: return attribute from env or underlying env. - - """ - if hasattr(self.__class__, name): - return self.__getattribute__(name) - else: - return getattr(self._tf_dataset, name) - - -def download_and_unzip_dataset( - env_name: str, - scenario_name: str, - dataset_base_dir: str = "./datasets", -) -> None: - dataset_download_url = DATASET_INFO[env_name][scenario_name]["url"] - - # TODO add check to see if dataset exists already. 
- - os.makedirs(f"{dataset_base_dir}/tmp/", exist_ok=True) - os.makedirs(f"{dataset_base_dir}/{env_name}/", exist_ok=True) - - zip_file_path = f"{dataset_base_dir}/tmp/tmp_dataset.zip" - - extraction_path = f"{dataset_base_dir}/{env_name}" - - response = requests.get(dataset_download_url, stream=True) # type: ignore - total_length = response.headers.get("content-length") - - with open(zip_file_path, "wb") as file: - if total_length is None: # no content length header - file.write(response.content) - else: - dl = 0 - total_length = int(total_length) # type: ignore - for data in response.iter_content(chunk_size=4096): - dl += len(data) - file.write(data) - done = int(50 * dl / total_length) # type: ignore - sys.stdout.write("\r[%s%s]" % ("=" * done, " " * (50 - done))) - sys.stdout.flush() - - # Step 2: Unzip the file - with zipfile.ZipFile(zip_file_path, "r") as zip_ref: - zip_ref.extractall(extraction_path) - - # Optionally, delete the zip file after extraction - os.remove(zip_file_path) - - def download_and_unzip_vault( env_name: str, scenario_name: str, diff --git a/og_marl/replay_buffers.py b/og_marl/replay_buffers.py index 328e1220..dcd64325 100644 --- a/og_marl/replay_buffers.py +++ b/og_marl/replay_buffers.py @@ -97,14 +97,13 @@ def populate_from_vault( ).read() # Recreate the buffer and associated pure functions - self._max_size = self._buffer_state.current_index self._replay_buffer = fbx.make_trajectory_buffer( add_batch_size=1, sample_batch_size=self._batch_size, sample_sequence_length=self._sequence_length, period=1, min_length_time_axis=1, - max_size=self._max_size, + max_size=self._sequence_length, ) self._buffer_sample_fn = jax.jit(self._replay_buffer.sample) self._buffer_add_fn = jax.jit(self._replay_buffer.add) diff --git a/og_marl/tf2/networks.py b/og_marl/tf2/networks.py new file mode 100644 index 00000000..2f219ff1 --- /dev/null +++ b/og_marl/tf2/networks.py @@ -0,0 +1,142 @@ +from typing import Sequence + +import tensorflow as tf +from tensorflow import Tensor +import sonnet as snt + + +class QMixer(snt.Module): + + """QMIX mixing network.""" + + def __init__( + self, + num_agents: int, + embed_dim: int = 32, + hypernet_embed: int = 64, + non_monotonic: bool = False, + ): + """Initialise QMIX mixing network + + Args: + ---- + num_agents: Number of agents in the environment + state_dim: Dimensions of the global environment state + embed_dim: The dimension of the output of the first layer + of the mixer. 
+ hypernet_embed: Number of units in the hyper network + + """ + super().__init__() + self.num_agents = num_agents + self.embed_dim = embed_dim + self.hypernet_embed = hypernet_embed + self._non_monotonic = non_monotonic + + self.hyper_w_1 = snt.Sequential( + [ + snt.Linear(self.hypernet_embed), + tf.nn.relu, + snt.Linear(self.embed_dim * self.num_agents), + ] + ) + + self.hyper_w_final = snt.Sequential( + [snt.Linear(self.hypernet_embed), tf.nn.relu, snt.Linear(self.embed_dim)] + ) + + # State dependent bias for hidden layer + self.hyper_b_1 = snt.Linear(self.embed_dim) + + # V(s) instead of a bias for the last layers + self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)]) + + def __call__(self, agent_qs: Tensor, states: Tensor) -> Tensor: + """Forward method.""" + B = agent_qs.shape[0] # batch size + state_dim = states.shape[2:] + + agent_qs = tf.reshape(agent_qs, (-1, 1, self.num_agents)) + + states = tf.reshape(states, (-1, *state_dim)) + + # First layer + w1 = self.hyper_w_1(states) + if not self._non_monotonic: + w1 = tf.abs(w1) + b1 = self.hyper_b_1(states) + w1 = tf.reshape(w1, (-1, self.num_agents, self.embed_dim)) + b1 = tf.reshape(b1, (-1, 1, self.embed_dim)) + hidden = tf.nn.elu(tf.matmul(agent_qs, w1) + b1) + + # Second layer + w_final = self.hyper_w_final(states) + if not self._non_monotonic: + w_final = tf.abs(w_final) + w_final = tf.reshape(w_final, (-1, self.embed_dim, 1)) + + # State-dependent bias + v = tf.reshape(self.V(states), (-1, 1, 1)) + + # Compute final output + y = tf.matmul(hidden, w_final) + v + + # Reshape and return + q_tot = tf.reshape(y, (B, -1, 1)) + + return q_tot + + def k(self, states: Tensor) -> Tensor: + """Method used by MAICQ.""" + B, T = states.shape[:2] + + w1 = tf.math.abs(self.hyper_w_1(states)) + w_final = tf.math.abs(self.hyper_w_final(states)) + w1 = tf.reshape(w1, shape=(-1, self.num_agents, self.embed_dim)) + w_final = tf.reshape(w_final, shape=(-1, self.embed_dim, 1)) + k = tf.matmul(w1, w_final) + k = tf.reshape(k, shape=(B, -1, self.num_agents)) + k = k / (tf.reduce_sum(k, axis=2, keepdims=True) + 1e-10) + return k + + +@snt.allow_empty_variables +class IdentityNetwork(snt.Module): + def __init__(self) -> None: + super().__init__() + return + + def __call__(self, x: Tensor) -> Tensor: + return x + + +class CNNEmbeddingNetwork(snt.Module): + def __init__( + self, output_channels: Sequence[int] = (8, 16), kernel_sizes: Sequence[int] = (3, 2) + ) -> None: + super().__init__() + assert len(output_channels) == len(kernel_sizes) + + layers = [] + for layer_i in range(len(output_channels)): + layers.append(snt.Conv2D(output_channels[layer_i], kernel_sizes[layer_i])) + layers.append(tf.nn.relu) + layers.append(tf.keras.layers.Flatten()) + + self.conv_net = snt.Sequential(layers) + + def __call__(self, x: Tensor) -> Tensor: + """Embed a pixel-styled input into a vector using a conv net. + + We assume the input has trailing dims + being the width, height and channel dimensions of the input. 
+ + The output shape is then given as (B,T,N,Embed) + """ + leading_dims = x.shape[:-3] + trailing_dims = x.shape[-3:] # W,H,C + + x = tf.reshape(x, shape=(-1, *trailing_dims)) + embed = self.conv_net(x) + embed = tf.reshape(embed, shape=(*leading_dims, -1)) + return embed diff --git a/og_marl/tf2/systems/bc.py b/og_marl/tf2/systems/bc.py index 5f8f047b..f0523b46 100644 --- a/og_marl/tf2/systems/bc.py +++ b/og_marl/tf2/systems/bc.py @@ -19,6 +19,7 @@ switch_two_leading_dims, unroll_rnn, ) +from og_marl.tf2.networks import IdentityNetwork class DicreteActionBehaviourCloning(BaseMARLSystem): @@ -33,6 +34,7 @@ def __init__( discount: float = 0.99, learning_rate: float = 1e-3, add_agent_id_to_obs: bool = True, + observation_embedding_network: Optional[snt.Module] = None, ): super().__init__( environment, logger, discount=discount, add_agent_id_to_obs=add_agent_id_to_obs @@ -48,6 +50,9 @@ def __init__( snt.Linear(self._environment._num_actions), ] ) # shared network for all agents + if observation_embedding_network is None: + observation_embedding_network = IdentityNetwork() + self._policy_embedding_network = observation_embedding_network self._optimizer = snt.optimizers.RMSProp(learning_rate=learning_rate) @@ -101,14 +106,13 @@ def _tf_select_actions( agent_observation, i, len(self._environment.possible_agents) ) agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension - logits, next_rnn_states[agent] = self._policy_network( - agent_observation, rnn_states[agent] - ) + embedding = self._policy_embedding_network(agent_observation) + logits, next_rnn_states[agent] = self._policy_network(embedding, rnn_states[agent]) probs = tf.nn.softmax(logits) if legal_actions is not None: - agent_legals = tf.expand_dims(legal_actions[agent], axis=0) + agent_legals = tf.cast(tf.expand_dims(legal_actions[agent], axis=0), "float32") probs = (probs * agent_legals) / tf.reduce_sum( probs * agent_legals ) # mask and renorm @@ -116,7 +120,7 @@ def _tf_select_actions( action = tfp.distributions.Categorical(probs=probs).sample(1) # Store agent action - actions[agent] = action + actions[agent] = action[0] return actions, next_rnn_states @@ -147,11 +151,16 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: resets = switch_two_leading_dims(resets) actions = switch_two_leading_dims(actions) + # Merge batch_dim and agent_dim + observations = merge_batch_and_agent_dim_of_time_major_sequence(observations) + resets = merge_batch_and_agent_dim_of_time_major_sequence(resets) + with tf.GradientTape() as tape: + embeddings = self._policy_embedding_network(observations) probs_out = unroll_rnn( self._policy_network, - merge_batch_and_agent_dim_of_time_major_sequence(observations), - merge_batch_and_agent_dim_of_time_major_sequence(resets), + embeddings, + resets, ) probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N) @@ -163,7 +172,10 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: bc_loss = tf.reduce_mean(bc_loss) # Apply gradients to policy - variables = (*self._policy_network.trainable_variables,) # Get trainable variables + variables = ( + *self._policy_network.trainable_variables, + *self._policy_embedding_network.trainable_variables, + ) # Get trainable variables gradients = tape.gradient(bc_loss, variables) # Compute gradients. 
self._optimizer.apply(gradients, variables) diff --git a/og_marl/tf2/systems/idrqn.py b/og_marl/tf2/systems/idrqn.py index 4826421d..55ba3448 100644 --- a/og_marl/tf2/systems/idrqn.py +++ b/og_marl/tf2/systems/idrqn.py @@ -14,7 +14,7 @@ """Implementation of independent Q-learning (DRQN style)""" import copy -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Sequence, Tuple, Optional import numpy as np import sonnet as snt @@ -37,6 +37,7 @@ switch_two_leading_dims, unroll_rnn, ) +from og_marl.tf2.networks import IdentityNetwork class IDRQNSystem(BaseMARLSystem): @@ -55,6 +56,7 @@ def __init__( eps_min: float = 0.05, eps_decay_timesteps: int = 50_000, add_agent_id_to_obs: bool = False, + observation_embedding_network: Optional[snt.Module] = None, ): super().__init__( environment, logger, add_agent_id_to_obs=add_agent_id_to_obs, discount=discount @@ -76,6 +78,12 @@ def __init__( ] ) # shared network for all agents + # Embedding network + if observation_embedding_network is None: + observation_embedding_network = IdentityNetwork() + self._q_embedding_network = observation_embedding_network + self._target_q_embedding_network = copy.deepcopy(observation_embedding_network) + # Target Q-network self._target_q_network = copy.deepcopy(self._q_network) self._target_update_period = target_update_period @@ -101,7 +109,7 @@ def reset(self) -> None: def select_actions( self, observations: Dict[str, np.ndarray], - legal_actions: Optional[Dict[str, np.ndarray]] = None, + legal_actions: Dict[str, np.ndarray], explore: bool = True, ) -> Dict[str, np.ndarray]: if explore: @@ -136,7 +144,8 @@ def _tf_select_actions( agent_observation, i, len(self._environment.possible_agents) ) agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension - q_values, next_rnn_states[agent] = self._q_network(agent_observation, rnn_states[agent]) + embedding = self._q_embedding_network(agent_observation) + q_values, next_rnn_states[agent] = self._q_network(embedding, rnn_states[agent]) agent_legal_actions = legal_actions[agent] masked_q_values = tf.where( @@ -149,7 +158,9 @@ def _tf_select_actions( epsilon = tf.maximum(1.0 - self._eps_dec * env_step_ctr, self._eps_min) greedy_probs = tf.one_hot(greedy_action, masked_q_values.shape[-1]) - explore_probs = agent_legal_actions / tf.reduce_sum(agent_legal_actions) + explore_probs = tf.cast( + agent_legal_actions / tf.reduce_sum(agent_legal_actions), "float32" + ) probs = (1.0 - epsilon) * greedy_probs + epsilon * explore_probs probs = tf.expand_dims(probs, axis=0) @@ -197,7 +208,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic resets = merge_batch_and_agent_dim_of_time_major_sequence(resets) # Unroll target network - target_qs_out = unroll_rnn(self._target_q_network, observations, resets) + target_embeddings = self._target_q_embedding_network(observations) + target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets) # Expand batch and agent_dim target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N) @@ -207,7 +219,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic with tf.GradientTape() as tape: # Unroll online network - qs_out = unroll_rnn(self._q_network, observations, resets) + embeddings = self._q_embedding_network(observations) + qs_out = unroll_rnn(self._q_network, embeddings, resets) # Expand batch and agent_dim qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N) @@ -241,7 
+254,10 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic loss = tf.reduce_mean(loss) # Get trainable variables - variables = (*self._q_network.trainable_variables,) + variables = ( + *self._q_network.trainable_variables, + *self._q_embedding_network.trainable_variables, + ) # Compute gradients. gradients = tape.gradient(loss, variables) @@ -250,10 +266,13 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic self._optimizer.apply(gradients, variables) # Online variables - online_variables = (*self._q_network.variables,) + online_variables = (*self._q_network.variables, *self._q_embedding_network.variables) # Get target variables - target_variables = (*self._target_q_network.variables,) + target_variables = ( + *self._target_q_network.variables, + *self._target_q_embedding_network.variables, + ) # Maybe update target network self._update_target_network(train_step_ctr, online_variables, target_variables) diff --git a/og_marl/tf2/systems/idrqn_bcq.py b/og_marl/tf2/systems/idrqn_bcq.py index 176a2ef4..a0cea06f 100644 --- a/og_marl/tf2/systems/idrqn_bcq.py +++ b/og_marl/tf2/systems/idrqn_bcq.py @@ -13,7 +13,7 @@ # limitations under the License. """Implementation of QMIX+BCQ""" -from typing import Any, Dict +from typing import Any, Dict, Optional import sonnet as snt import tensorflow as tf @@ -30,6 +30,7 @@ switch_two_leading_dims, unroll_rnn, ) +from og_marl.tf2.networks import IdentityNetwork class IDRQNBCQSystem(IDRQNSystem): @@ -47,6 +48,7 @@ def __init__( target_update_period: int = 200, learning_rate: float = 3e-4, add_agent_id_to_obs: bool = False, + observation_embedding_network: Optional[snt.Module] = None, ): super().__init__( environment, @@ -57,6 +59,7 @@ def __init__( discount=discount, target_update_period=target_update_period, learning_rate=learning_rate, + observation_embedding_network=observation_embedding_network, ) self._threshold = bc_threshold @@ -71,6 +74,10 @@ def __init__( ] ) + if observation_embedding_network is None: + observation_embedding_network = IdentityNetwork() + self._bc_embedding_network = observation_embedding_network + @tf.function(jit_compile=True) def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[str, Numeric]: # Unpack the batch @@ -100,7 +107,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st resets = merge_batch_and_agent_dim_of_time_major_sequence(resets) # Unroll target network - target_qs_out = unroll_rnn(self._target_q_network, observations, resets) + target_embeddings = self._target_q_embedding_network(observations) + target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets) # Expand batch and agent_dim target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N) @@ -110,7 +118,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st with tf.GradientTape() as tape: # Unroll online network - qs_out = unroll_rnn(self._q_network, observations, resets) + q_embeddings = self._q_embedding_network(observations) + qs_out = unroll_rnn(self._q_network, q_embeddings, resets) # Expand batch and agent_dim qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N) @@ -126,7 +135,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st ################### # Unroll behaviour cloning network - probs_out = unroll_rnn(self._behaviour_cloning_network, observations, resets) + bc_embeddings = 
self._bc_embedding_network(observations) + probs_out = unroll_rnn(self._behaviour_cloning_network, bc_embeddings, resets) # Expand batch and agent_dim probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N) @@ -175,7 +185,9 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st # Get trainable variables variables = ( *self._q_network.trainable_variables, + *self._q_embedding_network.trainable_variables, *self._behaviour_cloning_network.trainable_variables, + *self._bc_embedding_network.trainable_variables, ) # Compute gradients. @@ -185,10 +197,13 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st self._optimizer.apply(gradients, variables) # Online variables - online_variables = (*self._q_network.variables,) + online_variables = (*self._q_network.variables, *self._q_embedding_network.variables) # Get target variables - target_variables = (*self._target_q_network.variables,) + target_variables = ( + *self._target_q_network.variables, + *self._target_q_embedding_network.variables, + ) # Maybe update target network self._update_target_network(train_step, online_variables, target_variables) diff --git a/og_marl/tf2/systems/idrqn_cql.py b/og_marl/tf2/systems/idrqn_cql.py index bd9d3d61..1c5d1dfb 100644 --- a/og_marl/tf2/systems/idrqn_cql.py +++ b/og_marl/tf2/systems/idrqn_cql.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of IDRQN+CQL""" -from typing import Any, Dict +from typing import Any, Dict, Optional import tensorflow as tf +import sonnet as snt from chex import Numeric from og_marl.environments.base import BaseEnvironment from og_marl.loggers import BaseLogger -from og_marl.tf2.systems.qmix import QMIXSystem +from og_marl.tf2.systems.idrqn import IDRQNSystem from og_marl.tf2.utils import ( batch_concat_agent_id_to_obs, expand_batch_and_agent_dim_of_time_major_sequence, @@ -30,7 +31,7 @@ ) -class IDRQNCQLSystem(QMIXSystem): +class IDRQNCQLSystem(IDRQNSystem): """IDRQN+CQL System""" @@ -42,24 +43,22 @@ def __init__( cql_weight: float = 1.0, linear_layer_dim: int = 64, recurrent_layer_dim: int = 64, - mixer_embed_dim: int = 32, - mixer_hyper_dim: int = 64, discount: float = 0.99, target_update_period: int = 200, learning_rate: float = 3e-4, add_agent_id_to_obs: bool = False, + observation_embedding_network: Optional[snt.Module] = None, ): super().__init__( environment, logger, linear_layer_dim=linear_layer_dim, recurrent_layer_dim=recurrent_layer_dim, - mixer_embed_dim=mixer_embed_dim, - mixer_hyper_dim=mixer_hyper_dim, add_agent_id_to_obs=add_agent_id_to_obs, discount=discount, target_update_period=target_update_period, learning_rate=learning_rate, + observation_embedding_network=observation_embedding_network, ) # CQL @@ -95,7 +94,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st resets = merge_batch_and_agent_dim_of_time_major_sequence(resets) # Unroll target network - target_qs_out = unroll_rnn(self._target_q_network, observations, resets) + target_embeddings = self._target_q_embedding_network(observations) + target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets) # Expand batch and agent_dim target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N) @@ -105,7 +105,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st with tf.GradientTape() as tape: # Unroll online network - qs_out = 
unroll_rnn(self._q_network, observations, resets) + embeddings = self._q_embedding_network(observations) + qs_out = unroll_rnn(self._q_network, embeddings, resets) # Expand batch and agent_dim qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N) @@ -166,7 +167,10 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st loss = td_loss + cql_loss # Get trainable variables - variables = (*self._q_network.trainable_variables,) + variables = ( + *self._q_network.trainable_variables, + *self._q_embedding_network.trainable_variables, + ) # Compute gradients. gradients = tape.gradient(loss, variables) @@ -175,10 +179,13 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st self._optimizer.apply(gradients, variables) # Online variables - online_variables = (*self._q_network.variables,) + online_variables = (*self._q_network.variables, *self._q_embedding_network.variables) # Get target variables - target_variables = (*self._target_q_network.variables,) + target_variables = ( + *self._target_q_network.variables, + *self._target_q_embedding_network.variables, + ) # Maybe update target network self._update_target_network(train_step, online_variables, target_variables) diff --git a/og_marl/tf2/systems/maicq.py b/og_marl/tf2/systems/maicq.py index 4ac55c3a..de4b4051 100644 --- a/og_marl/tf2/systems/maicq.py +++ b/og_marl/tf2/systems/maicq.py @@ -13,7 +13,7 @@ # limitations under the License. """Implementation of MAICQ""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Tuple, Optional import numpy as np import sonnet as snt @@ -34,6 +34,7 @@ switch_two_leading_dims, unroll_rnn, ) +from og_marl.tf2.networks import IdentityNetwork class MAICQSystem(QMIXSystem): @@ -54,6 +55,8 @@ def __init__( target_update_period: int = 200, learning_rate: float = 3e-4, add_agent_id_to_obs: bool = False, + observation_embedding_network: Optional[snt.Module] = None, + state_embedding_network: Optional[snt.Module] = None, ): super().__init__( environment, @@ -66,12 +69,19 @@ def __init__( learning_rate=learning_rate, mixer_embed_dim=mixer_embed_dim, mixer_hyper_dim=mixer_hyper_dim, + observation_embedding_network=observation_embedding_network, + state_embedding_network=state_embedding_network, ) - # ICQ + # ICQ hyper-params self._icq_advantages_beta = icq_advantages_beta self._icq_target_q_taken_beta = icq_target_q_taken_beta + # Embedding Networks + if observation_embedding_network is None: + observation_embedding_network = IdentityNetwork() + self._policy_embedding_network = observation_embedding_network + # Policy Network self._policy_network = snt.DeepRNN( [ @@ -96,7 +106,7 @@ def reset(self) -> None: def select_actions( self, observations: Dict[str, np.ndarray], - legal_actions: Optional[Dict[str, np.ndarray]] = None, + legal_actions: Dict[str, np.ndarray], explore: bool = False, ) -> Dict[str, np.ndarray]: observations = tree.map_structure(tf.convert_to_tensor, observations) @@ -123,9 +133,8 @@ def _tf_select_actions( agent_observation, i, len(self._environment.possible_agents) ) agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension - probs, next_rnn_states[agent] = self._policy_network( - agent_observation, rnn_states[agent] - ) + embedding = self._policy_embedding_network(agent_observation) + probs, next_rnn_states[agent] = self._policy_network(embedding, rnn_states[agent]) agent_legal_actions = legal_actions[agent] masked_probs = tf.where( @@ -173,7 +182,8 @@ def _tf_train_step( 
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets) # Unroll target network - target_qs_out = unroll_rnn(self._target_q_network, observations, resets) + target_embeddings = self._target_q_embedding_network(observations) + target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets) # Expand batch and agent_dim target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N) @@ -183,7 +193,8 @@ def _tf_train_step( with tf.GradientTape(persistent=True) as tape: # Unroll online network - qs_out = unroll_rnn(self._q_network, observations, resets) + q_embeddings = self._q_embedding_network(observations) + qs_out = unroll_rnn(self._q_network, q_embeddings, resets) # Expand batch and agent_dim qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N) @@ -192,7 +203,8 @@ def _tf_train_step( q_vals = switch_two_leading_dims(qs_out) # Unroll the policy - probs_out = unroll_rnn(self._policy_network, observations, resets) + policy_embeddings = self._policy_embedding_network(observations) + probs_out = unroll_rnn(self._policy_network, policy_embeddings, resets) # Expand batch and agent_dim probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N) @@ -216,7 +228,9 @@ def _tf_train_step( pi_taken = gather(probs_out, actions, keepdims=False) log_pi_taken = tf.math.log(pi_taken) - coe = self._mixer.k(env_states) + env_state_embeddings = self._state_embedding_network(env_states) + target_env_state_embeddings = self._target_state_embedding_network(env_states) + coe = self._mixer.k(env_state_embeddings) coma_loss = -tf.reduce_mean(coe * (len(advantages) * advantages * log_pi_taken)) @@ -225,8 +239,8 @@ def _tf_train_step( target_q_taken = gather(target_q_vals, actions, axis=-1) # Mixing critics - q_taken = self._mixer(q_taken, env_states) - target_q_taken = self._target_mixer(target_q_taken, env_states) + q_taken = self._mixer(q_taken, env_state_embeddings) + target_q_taken = self._target_mixer(target_q_taken, target_env_state_embeddings) advantage_Q = tf.nn.softmax(target_q_taken / self._icq_target_q_taken_beta, axis=0) target_q_taken = len(advantage_Q) * advantage_Q * target_q_taken @@ -252,6 +266,8 @@ def _tf_train_step( *self._policy_network.trainable_variables, *self._q_network.trainable_variables, *self._mixer.trainable_variables, + *self._q_embedding_network.trainable_variables, + *self._policy_embedding_network.trainable_variables, ) # Get trainable variables gradients = tape.gradient(loss, variables) # Compute gradients. @@ -262,12 +278,16 @@ def _tf_train_step( online_variables = ( *self._q_network.variables, *self._mixer.variables, + *self._q_embedding_network.variables, + *self._state_embedding_network.variables, ) # Get target variables target_variables = ( *self._target_q_network.variables, *self._target_mixer.variables, + *self._target_q_embedding_network.variables, + *self._target_state_embedding_network.variables, ) # Maybe update target network diff --git a/og_marl/tf2/systems/qmix.py b/og_marl/tf2/systems/qmix.py index 2ac59a6e..04ac907d 100644 --- a/og_marl/tf2/systems/qmix.py +++ b/og_marl/tf2/systems/qmix.py @@ -13,8 +13,9 @@ # limitations under the License. 
"""Implementation of QMIX""" -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, Optional +import copy import sonnet as snt import tensorflow as tf from chex import Numeric @@ -31,6 +32,7 @@ switch_two_leading_dims, unroll_rnn, ) +from og_marl.tf2.networks import IdentityNetwork, QMixer class QMIXSystem(IDRQNSystem): @@ -50,6 +52,8 @@ def __init__( learning_rate: float = 3e-4, eps_decay_timesteps: int = 50_000, add_agent_id_to_obs: bool = False, + observation_embedding_network: Optional[snt.Module] = None, + state_embedding_network: Optional[snt.Module] = None, ): super().__init__( environment, @@ -61,8 +65,14 @@ def __init__( target_update_period=target_update_period, learning_rate=learning_rate, eps_decay_timesteps=eps_decay_timesteps, + observation_embedding_network=observation_embedding_network, ) + if state_embedding_network is None: + state_embedding_network = IdentityNetwork() + self._state_embedding_network = state_embedding_network + self._target_state_embedding_network = copy.deepcopy(state_embedding_network) + self._mixer = QMixer( len(self._environment.possible_agents), mixer_embed_dim, mixer_hyper_dim ) @@ -102,7 +112,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic resets = merge_batch_and_agent_dim_of_time_major_sequence(resets) # Unroll target network - target_qs_out = unroll_rnn(self._target_q_network, observations, resets) + target_embeddings = self._target_q_embedding_network(observations) + target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets) # Expand batch and agent_dim target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N) @@ -112,7 +123,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic with tf.GradientTape() as tape: # Unroll online network - qs_out = unroll_rnn(self._q_network, observations, resets) + q_network_embeddings = self._q_embedding_network(observations) + qs_out = unroll_rnn(self._q_network, q_network_embeddings, resets) # Expand batch and agent_dim qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N) @@ -131,8 +143,16 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic target_max_qs = gather(target_qs_out, cur_max_actions, axis=-1, keepdims=False) # Q-MIXING + env_state_embeddings, target_env_state_embeddings = ( + self._state_embedding_network(env_states), + self._target_state_embedding_network(env_states), + ) chosen_action_qs, target_max_qs, rewards = self._mixing( - chosen_action_qs, target_max_qs, env_states, rewards + chosen_action_qs, + target_max_qs, + env_state_embeddings, + target_env_state_embeddings, + rewards, ) # Reduce Agent Dim @@ -150,7 +170,12 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic loss = tf.reduce_mean(loss) # Get trainable variables - variables = (*self._q_network.trainable_variables, *self._mixer.trainable_variables) + variables = ( + *self._q_network.trainable_variables, + *self._q_embedding_network.trainable_variables, + *self._mixer.trainable_variables, + *self._state_embedding_network.trainable_variables, + ) # Compute gradients. 
         gradients = tape.gradient(loss, variables)
@@ -163,13 +188,17 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
         # Online variables
         online_variables = (
             *self._q_network.variables,
+            *self._q_embedding_network.variables,
             *self._mixer.variables,
+            *self._state_embedding_network.variables,
         )

         # Get target variables
         target_variables = (
             *self._target_q_network.variables,
+            *self._q_embedding_network.variables,
             *self._target_mixer.variables,
+            *self._target_state_embedding_network.variables,
         )

         # Maybe update target network
@@ -185,7 +214,8 @@ def _mixing(
         self,
         chosen_action_qs: Tensor,
         target_max_qs: Tensor,
-        states: Tensor,
+        state_embeddings: Tensor,
+        target_state_embeddings: Tensor,
         rewards: Tensor,
     ) -> Tuple[Tensor, Tensor, Tensor]:
         """QMIX"""
@@ -194,104 +224,7 @@ def _mixing(
         # target_max_qs = tf.reduce_sum(target_max_qs, axis=2, keepdims=True) # VDN

-        chosen_action_qs = self._mixer(chosen_action_qs, states)
-        target_max_qs = self._target_mixer(target_max_qs, states)
+        chosen_action_qs = self._mixer(chosen_action_qs, state_embeddings)
+        target_max_qs = self._target_mixer(target_max_qs, target_state_embeddings)
         rewards = tf.reduce_mean(rewards, axis=2, keepdims=True)

         return chosen_action_qs, target_max_qs, rewards
-
-
-class QMixer(snt.Module):
-
-    """QMIX mixing network."""
-
-    def __init__(
-        self,
-        num_agents: int,
-        embed_dim: int = 32,
-        hypernet_embed: int = 64,
-        preprocess_network: snt.Module = None,
-        non_monotonic: bool = False,
-    ):
-        """Initialise QMIX mixing network
-
-        Args:
-        ----
-            num_agents: Number of agents in the environment
-            state_dim: Dimensions of the global environment state
-            embed_dim: The dimension of the output of the first layer
-                of the mixer.
-            hypernet_embed: Number of units in the hyper network
-
-        """
-        super().__init__()
-        self.num_agents = num_agents
-        self.embed_dim = embed_dim
-        self.hypernet_embed = hypernet_embed
-        self._non_monotonic = non_monotonic
-
-        self.hyper_w_1 = snt.Sequential(
-            [
-                snt.Linear(self.hypernet_embed),
-                tf.nn.relu,
-                snt.Linear(self.embed_dim * self.num_agents),
-            ]
-        )
-
-        self.hyper_w_final = snt.Sequential(
-            [snt.Linear(self.hypernet_embed), tf.nn.relu, snt.Linear(self.embed_dim)]
-        )
-
-        # State dependent bias for hidden layer
-        self.hyper_b_1 = snt.Linear(self.embed_dim)
-
-        # V(s) instead of a bias for the last layers
-        self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)])
-
-    def __call__(self, agent_qs: Tensor, states: Tensor) -> Tensor:
-        """Forward method."""
-        B = agent_qs.shape[0]  # batch size
-        state_dim = states.shape[2:]
-
-        agent_qs = tf.reshape(agent_qs, (-1, 1, self.num_agents))
-
-        # states = tf.ones_like(states)
-        states = tf.reshape(states, (-1, *state_dim))
-
-        # First layer
-        w1 = self.hyper_w_1(states)
-        if not self._non_monotonic:
-            w1 = tf.abs(w1)
-        b1 = self.hyper_b_1(states)
-        w1 = tf.reshape(w1, (-1, self.num_agents, self.embed_dim))
-        b1 = tf.reshape(b1, (-1, 1, self.embed_dim))
-        hidden = tf.nn.elu(tf.matmul(agent_qs, w1) + b1)
-
-        # Second layer
-        w_final = self.hyper_w_final(states)
-        if not self._non_monotonic:
-            w_final = tf.abs(w_final)
-        w_final = tf.reshape(w_final, (-1, self.embed_dim, 1))
-
-        # State-dependent bias
-        v = tf.reshape(self.V(states), (-1, 1, 1))
-
-        # Compute final output
-        y = tf.matmul(hidden, w_final) + v
-
-        # Reshape and return
-        q_tot = tf.reshape(y, (B, -1, 1))
-
-        return q_tot
-
-    def k(self, states: Tensor) -> Tensor:
-        """Method used by MAICQ."""
-        B, T = states.shape[:2]
-
-        w1 = tf.math.abs(self.hyper_w_1(states))
-        w_final = tf.math.abs(self.hyper_w_final(states))
-        w1 = tf.reshape(w1, shape=(-1, self.num_agents, self.embed_dim))
-        w_final = tf.reshape(w_final, shape=(-1, self.embed_dim, 1))
-        k = tf.matmul(w1, w_final)
-        k = tf.reshape(k, shape=(B, -1, self.num_agents))
-        k = k / (tf.reduce_sum(k, axis=2, keepdims=True) + 1e-10)
-        return k
diff --git a/og_marl/tf2/systems/qmix_bcq.py b/og_marl/tf2/systems/qmix_bcq.py
index 715995e7..70ecad6c 100644
--- a/og_marl/tf2/systems/qmix_bcq.py
+++ b/og_marl/tf2/systems/qmix_bcq.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 """Implementation of QMIX+BCQ"""
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import sonnet as snt
 import tensorflow as tf
@@ -30,6 +30,7 @@
     switch_two_leading_dims,
     unroll_rnn,
 )
+from og_marl.tf2.networks import IdentityNetwork


 class QMIXBCQSystem(QMIXSystem):
@@ -49,6 +50,8 @@ def __init__(
         target_update_period: int = 200,
         learning_rate: float = 3e-4,
         add_agent_id_to_obs: bool = False,
+        observation_embedding_network: Optional[snt.Module] = None,
+        state_embedding_network: Optional[snt.Module] = None,
     ):
         super().__init__(
             environment,
@@ -61,6 +64,8 @@ def __init__(
             discount=discount,
             target_update_period=target_update_period,
             learning_rate=learning_rate,
+            observation_embedding_network=observation_embedding_network,
+            state_embedding_network=state_embedding_network,
         )

         self._threshold = bc_threshold
@@ -75,6 +80,10 @@ def __init__(
             ]
         )

+        if observation_embedding_network is None:
+            observation_embedding_network = IdentityNetwork()
+        self._bc_embedding_network = observation_embedding_network
+
     @tf.function(jit_compile=True)
     def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[str, Numeric]:
         # Unpack the batch
@@ -105,7 +114,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
         resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)

         # Unroll target network
-        target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+        target_embeddings = self._target_q_embedding_network(observations)
+        target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)

         # Expand batch and agent_dim
         target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -115,7 +125,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st

         with tf.GradientTape() as tape:
             # Unroll online network
-            qs_out = unroll_rnn(self._q_network, observations, resets)
+            q_embeddings = self._q_embedding_network(observations)
+            qs_out = unroll_rnn(self._q_network, q_embeddings, resets)

             # Expand batch and agent_dim
             qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -131,7 +142,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
             ###################

             # Unroll behaviour cloning network
-            probs_out = unroll_rnn(self._behaviour_cloning_network, observations, resets)
+            bc_embeddings = self._bc_embedding_network(observations)
+            probs_out = unroll_rnn(self._behaviour_cloning_network, bc_embeddings, resets)

             # Expand batch and agent_dim
             probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N)
@@ -162,9 +174,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
             ####### END #######
             ###################

-            # Maybe do mixing (e.g. QMIX) but not in independent system
+            # Q-MIXING
+            env_state_embeddings, target_env_state_embeddings = (
+                self._state_embedding_network(env_states),
+                self._target_state_embedding_network(env_states),
+            )
             chosen_action_qs, target_max_qs, rewards = self._mixing(
-                chosen_action_qs, target_max_qs, env_states, rewards
+                chosen_action_qs,
+                target_max_qs,
+                env_state_embeddings,
+                target_env_state_embeddings,
+                rewards,
             )

             # Compute targets
@@ -185,7 +205,9 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st

         # Get trainable variables
         variables = (
             *self._q_network.trainable_variables,
+            *self._q_embedding_network.trainable_variables,
             *self._mixer.trainable_variables,
+            *self._state_embedding_network.trainable_variables,
             *self._behaviour_cloning_network.trainable_variables,
         )
@@ -198,13 +220,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
         # Online variables
         online_variables = (
             *self._q_network.variables,
+            *self._q_embedding_network.variables,
             *self._mixer.variables,
+            *self._state_embedding_network.variables,
         )

         # Get target variables
         target_variables = (
             *self._target_q_network.variables,
+            *self._target_q_embedding_network.variables,
             *self._target_mixer.variables,
+            *self._target_state_embedding_network.variables,
         )

         # Maybe update target network
diff --git a/og_marl/tf2/systems/qmix_cql.py b/og_marl/tf2/systems/qmix_cql.py
index 4fe7abf3..9946252f 100644
--- a/og_marl/tf2/systems/qmix_cql.py
+++ b/og_marl/tf2/systems/qmix_cql.py
@@ -13,9 +13,10 @@
 # limitations under the License.

 """Implementation of QMIX+CQL"""
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import tensorflow as tf
+import sonnet as snt
 from chex import Numeric

 from og_marl.environments.base import BaseEnvironment
@@ -52,6 +53,8 @@ def __init__(
         target_update_period: int = 200,
         learning_rate: float = 3e-4,
         add_agent_id_to_obs: bool = False,
+        observation_embedding_network: Optional[snt.Module] = None,
+        state_embedding_network: Optional[snt.Module] = None,
     ):
         super().__init__(
             environment,
@@ -64,6 +67,8 @@ def __init__(
             discount=discount,
             target_update_period=target_update_period,
             learning_rate=learning_rate,
+            observation_embedding_network=observation_embedding_network,
+            state_embedding_network=state_embedding_network,
         )

         # CQL
@@ -100,7 +105,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
         resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)

         # Unroll target network
-        target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+        target_embeddings = self._target_q_embedding_network(observations)
+        target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)

         # Expand batch and agent_dim
         target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -110,7 +116,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st

         with tf.GradientTape() as tape:
             # Unroll online network
-            qs_out = unroll_rnn(self._q_network, observations, resets)
+            q_embeddings = self._q_embedding_network(observations)
+            qs_out = unroll_rnn(self._q_network, q_embeddings, resets)

             # Expand batch and agent_dim
             qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -128,9 +135,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
             cur_max_actions = tf.argmax(qs_out_selector, axis=3)
             target_max_qs = gather(target_qs_out, cur_max_actions, axis=-1)

-            # Maybe do mixing (e.g. QMIX) but not in independent system
+            # Q-MIXING
+            env_state_embeddings, target_env_state_embeddings = (
+                self._state_embedding_network(env_states),
+                self._target_state_embedding_network(env_states),
+            )
             chosen_action_qs, target_max_qs, rewards = self._mixing(
-                chosen_action_qs, target_max_qs, env_states, rewards
+                chosen_action_qs,
+                target_max_qs,
+                env_state_embeddings,
+                target_env_state_embeddings,
+                rewards,
             )

             # Compute targets
@@ -159,7 +174,7 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
                 )  # [B, T, N]

                 # Mixing
-                mixed_ood_qs = self._mixer(ood_qs, env_states)  # [B, T, 1]
+                mixed_ood_qs = self._mixer(ood_qs, env_state_embeddings)  # [B, T, 1]
                 all_mixed_ood_qs.append(mixed_ood_qs)  # [B, T, Ra]

             all_mixed_ood_qs.append(chosen_action_qs)  # [B, T, Ra + 1]
@@ -177,7 +192,12 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
             loss = td_loss + cql_loss

         # Get trainable variables
-        variables = (*self._q_network.trainable_variables, *self._mixer.trainable_variables)
+        variables = (
+            *self._q_network.trainable_variables,
+            *self._q_embedding_network.trainable_variables,
+            *self._mixer.trainable_variables,
+            *self._state_embedding_network.trainable_variables,
+        )

         # Compute gradients.
         gradients = tape.gradient(loss, variables)
@@ -188,13 +208,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
         # Online variables
         online_variables = (
             *self._q_network.variables,
+            *self._q_embedding_network.variables,
             *self._mixer.variables,
+            *self._state_embedding_network.variables,
         )

         # Get target variables
         target_variables = (
             *self._target_q_network.variables,
+            *self._target_q_embedding_network.variables,
             *self._target_mixer.variables,
+            *self._target_state_embedding_network.variables,
         )

         # Maybe update target network
diff --git a/setup.py b/setup.py
index dc5e1cc2..73ebd824 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,8 @@
         "gymnasium",
         "requests",
         "jax[cpu]==0.4.20",
+        "matplotlib",
+        "seaborn",
         # "flashbax==0.1.0",  # install post
     ],
     extras_require={
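The hunks above consistently route observations through `self._q_embedding_network` / `self._target_q_embedding_network` before `unroll_rnn`, and global environment states through `self._state_embedding_network` before the mixer, with the behaviour-cloning branch falling back to `IdentityNetwork` when no module is supplied. As a minimal sketch of what could be plugged in through the new `observation_embedding_network` argument for pixel observations: the class name `PixelEmbedding`, the layer sizes, and the assumed `[T, B*N, H, W, C]` observation layout below are illustrative assumptions, not part of this patch.

```python
import sonnet as snt
import tensorflow as tf


class PixelEmbedding(snt.Module):
    """Hypothetical CNN torso for pixel observations (illustration only)."""

    def __init__(self, output_dim: int = 64):
        super().__init__()
        self._torso = snt.Sequential(
            [
                snt.Conv2D(16, kernel_shape=3, stride=2),
                tf.nn.relu,
                snt.Conv2D(32, kernel_shape=3, stride=2),
                tf.nn.relu,
                snt.Flatten(),
                snt.Linear(output_dim),
            ]
        )

    def __call__(self, observations: tf.Tensor) -> tf.Tensor:
        # Assumes a statically shaped [T, B*N, H, W, C] input, i.e. the
        # time-major, batch-and-agent-merged layout at the point where the
        # patch applies the embedding before `unroll_rnn`.
        T, BN = observations.shape[0], observations.shape[1]
        x = tf.reshape(observations, (-1, *observations.shape[2:]))  # [T*B*N, H, W, C]
        embeddings = self._torso(x)  # [T*B*N, output_dim]
        return tf.reshape(embeddings, (T, BN, -1))  # [T, B*N, output_dim]
```

An instance of such a module would be handed to the systems via the `observation_embedding_network` keyword added in this patch; when the argument is omitted, the code above keeps the identity behaviour, so flat-vector environments are unaffected.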