diff --git a/Dockerfile b/Dockerfile
index efcfa61b..4b44b989 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,7 @@ FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
# Ensure no installs try to launch interactive screen
ARG DEBIAN_FRONTEND=noninteractive
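+# Stream Python stdout/stderr straight to the container logs (disable output buffering)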
+ENV PYTHONUNBUFFERED=1
# Update packages and install python3.9 and other dependencies
RUN apt-get update -y && \
@@ -33,14 +34,18 @@ RUN pip install --quiet --upgrade pip setuptools wheel && \
pip install -e . && \
pip install flashbax==0.1.0
-ENV SC2PATH /home/app/StarCraftII
+# ENV SC2PATH /home/app/StarCraftII
# RUN ./install_environments/smacv1.sh
-RUN ./install_environments/smacv2.sh
+# RUN ./install_environments/smacv2.sh
# ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia
# ENV SUPPRESS_GR_PROMPT 1
# RUN ./install_environments/mamujoco.sh
+RUN ./install_environments/pettingzoo.sh
+
+# RUN ./install_environments/flatland.sh
+
# Copy all code
COPY ./examples ./examples
COPY ./baselines ./baselines
diff --git a/README.md b/README.md
index 92ee6825..201306d7 100644
--- a/README.md
+++ b/README.md
@@ -104,12 +104,12 @@ We are in the process of migrating our datasets from TF Records to Flashbax Vaul
| 💣SMAC v2 | terran_5_vs_5<br/>zerg_5_vs_5<br/>terran_10_vs_10 | 5<br/>5<br/>10 | Discrete | Vector | Dense | Heterog | [source](https://github.com/oxwhirl/smacv2) |
| 🚅Flatland | 3 Trains<br/>5 Trains | 3<br/>5 | Discrete | Vector | Sparse | Homog | [source](https://flatland.aicrowd.com/intro.html) |
| 🐜MAMuJoCo | 2-HalfCheetah<br/>2-Ant<br/>4-Ant | 2<br/>2<br/>4 | Cont. | Vector | Dense | Heterog<br/>Homog<br/>Homog | [source](https://github.com/schroederdewitt/multiagent_mujoco) |
-
+| 🐻PettingZoo | Pursuit<br/>Co-op Pong | 8<br/>2 | Discrete<br/>Discrete | Pixels<br/>Pixels | Dense | Homog<br/>Heterog | [source](https://pettingzoo.farama.org/) |
### Legacy Datasets (still to be migrated to Vault) 👴
| Environment | Scenario | Agents | Act | Obs | Reward | Types | Repo |
|-----|----|----|-----|-----|----|----|-----|
-| 🐻PettingZoo | Pursuit<br/>Co-op Pong<br/>PistonBall<br/>KAZ| 8<br/>2<br/>15<br/>2| Discrete<br/>Discrete<br/>Cont.<br/>Discrete | Pixels<br/>Pixels<br/>Pixels<br/>Vector | Dense | Homog<br/>Heterog<br/>Homog<br/>Heterog| [source](https://pettingzoo.farama.org/) |
+| 🐻PettingZoo | PistonBall<br/>KAZ| 15<br/>2| Cont.<br/>Discrete | Pixels<br/>Vector | Dense | Homog<br/>Heterog| [source](https://pettingzoo.farama.org/) |
| 🏙️CityLearn | 2022_all_phases | 17 | Cont. | Vector | Dense | Homog | [source](https://github.com/intelligent-environments-lab/CityLearn) |
| 🔌Voltage Control | case33_3min_final | 6 | Cont. | Vector | Dense | Homog | [source](https://github.com/Future-Power-Networks/MAPDN) |
| 🔴MPE | simple_adversary | 3 | Discrete. | Vector | Dense | Competitive | [source](https://pettingzoo.farama.org/environments/mpe/simple_adversary/) |
diff --git a/baselines/main.py b/baselines/main.py
index 1fdb8be0..0f439fc4 100644
--- a/baselines/main.py
+++ b/baselines/main.py
@@ -13,20 +13,21 @@
# limitations under the License.
from absl import app, flags
-from og_marl.environments.utils import get_environment
+from og_marl.environments import get_environment
from og_marl.loggers import JsonWriter, WandbLogger
from og_marl.offline_dataset import download_and_unzip_vault
from og_marl.replay_buffers import FlashbaxReplayBuffer
+from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory
set_growing_gpu_memory()
FLAGS = flags.FLAGS
-flags.DEFINE_string("env", "smac_v1", "Environment name.")
-flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
+flags.DEFINE_string("env", "pettingzoo", "Environment name.")
+flags.DEFINE_string("scenario", "pursuit", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
-flags.DEFINE_string("system", "dbc", "System name.")
+flags.DEFINE_string("system", "qmix", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
flags.DEFINE_integer("batch_size", 64, "Number of training steps.")
@@ -43,7 +44,7 @@ def main(_):
env = get_environment(FLAGS.env, FLAGS.scenario)
- buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=2)
+ buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)
download_and_unzip_vault(FLAGS.env, FLAGS.scenario)
@@ -65,6 +66,9 @@ def main(_):
)
system_kwargs = {"add_agent_id_to_obs": True}
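+ # Pursuit observations are image-shaped (7x7x3), so embed them with a small CNN before the recurrent Q-network.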
+ if FLAGS.scenario == "pursuit":
+ system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
+
system = get_system(FLAGS.system, env, logger, **system_kwargs)
system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)
diff --git a/examples/tf2/run_all_baselines.py b/examples/tf2/run_all_baselines.py
index 51caece7..957a2b71 100644
--- a/examples/tf2/run_all_baselines.py
+++ b/examples/tf2/run_all_baselines.py
@@ -1,32 +1,45 @@
import os
+# For printing full tracebacks when a baseline run fails
+import traceback
+
from og_marl.environments import get_environment
-from og_marl.loggers import JsonWriter, WandbLogger
+from og_marl.loggers import TerminalLogger, JsonWriter
from og_marl.replay_buffers import FlashbaxReplayBuffer
+from og_marl.tf2.networks import CNNEmbeddingNetwork
from og_marl.tf2.systems import get_system
from og_marl.tf2.utils import set_growing_gpu_memory
set_growing_gpu_memory()
-os.environ["SUPPRESS_GR_PROMPT"] = 1
+# For MAMuJoCo
+os.environ["SUPPRESS_GR_PROMPT"] = "1"
scenario_system_configs = {
"smac_v1": {
"3m": {
- "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq"],
+ "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
- "trainer_steps": 3000,
+ "trainer_steps": 2000,
"evaluate_every": 1000,
},
},
- "mamujoco": {
- "2halfcheetah": {
- "systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
+ "pettingzoo": {
+ "pursuit": {
+ "systems": ["idrqn", "idrqn+cql", "idrqn+bcq", "qmix+cql", "qmix+bcq", "maicq", "dbc"],
"datasets": ["Good"],
- "trainer_steps": 3000,
+ "trainer_steps": 2000,
"evaluate_every": 1000,
},
},
+ # "mamujoco": {
+ # "2halfcheetah": {
+ # "systems": ["iddpg", "iddpg+cql", "maddpg+cql", "maddpg", "omar"],
+ # "datasets": ["Good"],
+ # "trainer_steps": 3000,
+ # "evaluate_every": 1000,
+ # },
+ # },
}
seeds = [42]
@@ -44,7 +57,7 @@
"system": env_name,
"seed": seed,
}
- logger = WandbLogger(config, project="og-marl-baselines")
+ logger = TerminalLogger()
env = get_environment(env_name, scenario_name)
buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)
@@ -55,10 +68,18 @@
raise ValueError("Vault not found. Exiting.")
json_writer = JsonWriter(
- "logs", system_name, f"{scenario_name}_{dataset_name}", env_name, seed
+ "test_all_baselines",
+ system_name,
+ f"{scenario_name}_{dataset_name}",
+ env_name,
+ seed,
)
system_kwargs = {"add_agent_id_to_obs": True}
+
+ if scenario_name == "pursuit":
+ system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()
+
system = get_system(system_name, env, logger, **system_kwargs)
trainer_steps = scenario_system_configs[env_name][scenario_name][
@@ -75,7 +96,7 @@
)
except: # noqa: E722
logger.close()
- print()
- print("BROKEN")
+ print("BROKEN:", env_name, scenario_name, system_name)
+ traceback.print_exc()
print()
continue
diff --git a/install_environments/requirements/pettingzoo.txt b/install_environments/requirements/pettingzoo.txt
index a0dc1340..151bd05e 100755
--- a/install_environments/requirements/pettingzoo.txt
+++ b/install_environments/requirements/pettingzoo.txt
@@ -2,7 +2,7 @@ autorom
gym
numpy
opencv-python
-pettingzoo==1.22.0
+pettingzoo==1.23.1
pygame
pymunk
scipy
diff --git a/mkdocs.yml b/mkdocs.yml
index a8bccadc..75ace30d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -38,7 +38,6 @@ theme:
nav:
- Home: 'index.md'
- - Datasets: 'datasets.md'
- Baseline Results: 'baselines.md'
- Updates: 'updates.md'
- API Reference: 'api.md'
diff --git a/og_marl/environments/__init__.py b/og_marl/environments/__init__.py
index c1439ff2..846c161c 100644
--- a/og_marl/environments/__init__.py
+++ b/og_marl/environments/__init__.py
@@ -18,7 +18,7 @@
from og_marl.environments.base import BaseEnvironment
-def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
+def get_environment(env_name: str, scenario: str) -> BaseEnvironment: # noqa: C901
if env_name == "smac_v1":
from og_marl.environments.smacv1 import SMACv1
@@ -31,6 +31,14 @@ def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
from og_marl.environments.old_mamujoco import MAMuJoCo
return MAMuJoCo(scenario)
+ elif scenario == "pursuit":
+ from og_marl.environments.pursuit import Pursuit
+
+ return Pursuit()
+ elif scenario == "coop_pong":
+ from og_marl.environments.coop_pong import CooperativePong
+
+ return CooperativePong()
elif env_name == "gymnasium_mamujoco":
from og_marl.environments.gymnasium_mamujoco import MAMuJoCo
diff --git a/og_marl/environments/coop_pong.py b/og_marl/environments/coop_pong.py
new file mode 100644
index 00000000..0c259f26
--- /dev/null
+++ b/og_marl/environments/coop_pong.py
@@ -0,0 +1,114 @@
+# python3
+# Copyright 2021 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wrapper for Cooperative Pettingzoo environments."""
+from typing import Any, List, Dict
+
+import numpy as np
+from pettingzoo.butterfly import cooperative_pong_v5
+import supersuit
+
+from og_marl.environments.base import BaseEnvironment
+from og_marl.environments.base import Observations, ResetReturn, StepReturn
+
+
+class CooperativePong(BaseEnvironment):
+ """Environment wrapper PettingZoo Cooperative Pong."""
+
+ def __init__(
+ self,
+ ) -> None:
+ """Constructor."""
+ self._environment = cooperative_pong_v5.parallel_env(render_mode="rgb_array")
+ # Wrap environment with supersuit pre-process wrappers
+ self._environment = supersuit.color_reduction_v0(self._environment, mode="R")
+ self._environment = supersuit.resize_v0(self._environment, x_size=145, y_size=84)
+ self._environment = supersuit.dtype_v0(self._environment, dtype="float32")
+ self._environment = supersuit.normalize_obs_v0(self._environment)
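+ # After these wrappers each agent sees a single-channel 84x145 float32 frame with normalised values.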
+
+ self._agents = self._environment.possible_agents
+ # All actions are always legal in Cooperative Pong, so the legal-action mask is all ones.
+ self._legals = {
+ agent: np.ones((self._environment.action_space(agent).n,), "int32")
+ for agent in self._agents
+ }
+ self._done = False
+ self.max_episode_length = 900
+
+ def reset(self) -> ResetReturn:
+ """Resets the env."""
+ # Reset the environment
+ observations, _ = self._environment.reset() # type: ignore
+
+ # Global state (built from the full frames, before per-agent cropping)
+ env_state = self._create_state_representation(observations, first=True)
+
+ # Convert observations so each paddle only sees its own side
+ observations = self._convert_observations(observations)
+
+ # Infos
+ info = {"state": env_state, "legals": self._legals}
+
+ return observations, info
+
+ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
+ """Steps in env."""
+ # Step the environment
+ observations, rewards, terminals, truncations, _ = self._environment.step(actions)
+
+ # Global state (built from the full frames, before per-agent cropping)
+ env_state = self._create_state_representation(observations)
+
+ # Convert observations so each paddle only sees its own side
+ observations = self._convert_observations(observations)
+
+ # Extra infos
+ info = {"state": env_state, "legals": self._legals}
+
+ return observations, rewards, terminals, truncations, info
+
+ def _create_state_representation(self, observations: Observations, first: bool = False) -> Any:
+ if first:
+ self._state_history = np.zeros((84, 145, 4), "float32")
+
+ state = np.expand_dims(observations["paddle_0"][:, :], axis=-1)
+
+ # framestacking
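+ # The newest frame occupies channel 0; the oldest of the four stacked frames is dropped.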
+ self._state_history = np.concatenate((state, self._state_history[:, :, :3]), axis=-1)
+
+ return self._state_history
+
+ def _convert_observations(self, observations: List) -> Observations:
+ """Make observations partial."""
+ processed_observations = {}
+ for agent in self._agents:
+ if agent == "paddle_0":
+ agent_obs = observations[agent][:, :110] # hide the other agent
+ else:
+ agent_obs = observations[agent][:, 35:] # hide the other agent
+
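+ # Add a trailing channel dimension so the conv embedding network receives (H, W, C) inputs.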
+ agent_obs = np.expand_dims(agent_obs, axis=-1)
+ processed_observations[agent] = agent_obs
+
+ return processed_observations
+
+ def __getattr__(self, name: str) -> Any:
+ """Expose any other attributes of the underlying environment.
+
+ Args:
+ name (str): attribute.
+
+ Returns:
+ Any: return attribute from env or underlying env.
+ """
+ if hasattr(self.__class__, name):
+ return self.__getattribute__(name)
+ else:
+ return getattr(self._environment, name)
diff --git a/og_marl/environments/pettingzoo_base.py b/og_marl/environments/pettingzoo_base.py
index edbfd13d..bbbef7c4 100644
--- a/og_marl/environments/pettingzoo_base.py
+++ b/og_marl/environments/pettingzoo_base.py
@@ -30,7 +30,7 @@ def __init__(self) -> None:
def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
- observations = self._environment.reset() # type: ignore
+ observations, _ = self._environment.reset() # type: ignore
# Global state
env_state = self._create_state_representation(observations)
diff --git a/og_marl/environments/pursuit.py b/og_marl/environments/pursuit.py
index e6bb028e..9050cb41 100644
--- a/og_marl/environments/pursuit.py
+++ b/og_marl/environments/pursuit.py
@@ -11,24 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Dict
import numpy as np
from gymnasium.spaces import Box, Discrete
from pettingzoo.sisl import pursuit_v4
-from supersuit import black_death_v3
-from og_marl.environments.base import Observations
-from og_marl.environments.pettingzoo_base import PettingZooBase
+from og_marl.environments.base import BaseEnvironment, Observations, ResetReturn, StepReturn
-class Pursuit(PettingZooBase):
+class Pursuit(BaseEnvironment):
"""Environment wrapper for Pursuit."""
def __init__(self) -> None:
"""Constructor for Pursuit"""
- self._environment = black_death_v3(pursuit_v4.parallel_env())
+ self._environment = pursuit_v4.parallel_env()
self.possible_agents = self._environment.possible_agents
self._num_actions = 5
self._obs_dim = (7, 7, 3)
@@ -38,11 +36,39 @@ def __init__(self) -> None:
agent: Box(-np.inf, np.inf, (*self._obs_dim,)) for agent in self.possible_agents
}
- self.info_spec = {"state": np.zeros(8 * 2 + 30 * 2, "float32")}
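+ # All five pursuit actions are always available, so the legal-action mask is all ones.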
+ self._legals = {
+ agent: np.ones((self._num_actions,), "int32") for agent in self.possible_agents
+ }
+
+ self.info_spec = {"state": np.zeros(8 * 2 + 30 * 2, "float32"), "legals": self._legals}
+
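+ # Matches pursuit_v4's default max_cycles of 500.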
+ self.max_episode_length = 500
+
+ def reset(self) -> ResetReturn:
+ """Resets the env."""
+ # Reset the environment
+ observations, _ = self._environment.reset() # type: ignore
+
+ # Global state
+ env_state = self._create_state_representation(observations)
+
+ # Infos
+ info = {"state": env_state, "legals": self._legals}
+
+ return observations, info
+
+ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
+ """Steps in env."""
+ # Step the environment
+ observations, rewards, terminals, truncations, _ = self._environment.step(actions)
+
+ # Global state
+ env_state = self._create_state_representation(observations)
+
+ # Extra infos
+ info = {"state": env_state, "legals": self._legals}
- def _convert_observations(self, observations: Observations, done: bool) -> Observations:
- """Convert observations."""
- return observations
+ return observations, rewards, terminals, truncations, info
def _create_state_representation(self, observations: Observations) -> np.ndarray:
pursuer_pos = [
diff --git a/og_marl/offline_dataset.py b/og_marl/offline_dataset.py
index 300734ee..d209a07c 100644
--- a/og_marl/offline_dataset.py
+++ b/og_marl/offline_dataset.py
@@ -15,22 +15,17 @@
import os
import sys
import zipfile
-from pathlib import Path
-from typing import Any, Dict, List
+from typing import Dict, List
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import requests # type: ignore
import seaborn as sns
-import tensorflow as tf
-import tree
from chex import Array
from flashbax.vault import Vault
from git import Optional
-from tensorflow import DType
-from og_marl.environments.base import BaseEnvironment
VAULT_INFO = {
"smac_v1": {
@@ -64,236 +59,15 @@
"4ant": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/4ant.zip"},
},
"flatland": {
- "5trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/3trains.zip"},
- "2trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/5trains.zip"},
+ "3trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/3trains.zip"},
+ "5trains": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/5trains.zip"},
},
-}
-
-DATASET_INFO = {
- "smac_v1": {
- "3m": {"url": "https://tinyurl.com/3m-dataset", "sequence_length": 20, "period": 10},
- "8m": {"url": "https://tinyurl.com/8m-dataset", "sequence_length": 20, "period": 10},
- "5m_vs_6m": {
- "url": "https://tinyurl.com/5m-vs-6m-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- "2s3z": {"url": "https://tinyurl.com/2s3z-dataset", "sequence_length": 20, "period": 10},
- "3s5z_vs_3s6z": {
- "url": "https://tinyurl.com/3s5z-vs-3s6z-dataset3",
- "sequence_length": 20,
- "period": 10,
- },
- "2c_vs_64zg": {
- "url": "https://tinyurl.com/2c-vs-64zg-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- "27m_vs_30m": {
- "url": "https://tinyurl.com/27m-vs-30m-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- },
- "smac_v2": {
- "terran_5_vs_5": {
- "url": "https://tinyurl.com/terran-5-vs-5-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- "zerg_5_vs_5": {
- "url": "https://tinyurl.com/zerg-5-vs-5-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- "terran_10_vs_10": {
- "url": "https://tinyurl.com/terran-10-vs-10-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- },
- "flatland": {
- "3trains": {
- "url": "https://tinyurl.com/3trains-dataset",
- "sequence_length": 20, # TODO
- "period": 10,
- },
- "5trains": {
- "url": "https://tinyurl.com/5trains-dataset",
- "sequence_length": 20, # TODO
- "period": 10,
- },
- },
- "mamujoco": {
- "2halfcheetah": {
- "url": "https://tinyurl.com/2halfcheetah-dataset",
- "sequence_length": 20,
- "period": 10,
- },
- "2ant": {"url": "https://tinyurl.com/2ant-dataset", "sequence_length": 20, "period": 10},
- "4ant": {"url": "https://tinyurl.com/4ant-dataset", "sequence_length": 20, "period": 10},
- },
- "voltage_control": {
- "case33_3min_final": {
- "url": "https://tinyurl.com/case33-3min-final-dataset",
- "sequence_length": 20,
- "period": 10,
- },
+ "pettingzoo": {
+ "pursuit": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/vaults/pursuit.zip"}
},
}
-def get_schema_dtypes(environment: BaseEnvironment) -> Dict[str, DType]:
- act_type = list(environment.action_spaces.values())[0].dtype
- schema = {}
- for agent in environment.possible_agents:
- schema[agent + "_observations"] = tf.float32
- schema[agent + "_legal_actions"] = tf.float32
- schema[agent + "_actions"] = act_type
- schema[agent + "_rewards"] = tf.float32
- schema[agent + "_discounts"] = tf.float32
-
- ## Extras
- # Zero-padding mask
- schema["zero_padding_mask"] = tf.float32
-
- # Env state
- schema["env_state"] = tf.float32
-
- # Episode return
- schema["episode_return"] = tf.float32
-
- return schema
-
-
-class OfflineMARLDataset:
- def __init__(
- self,
- environment: BaseEnvironment,
- env_name: str,
- scenario_name: str,
- dataset_type: str,
- base_dataset_dir: str = "./datasets",
- ):
- self._environment = environment
- self._schema = get_schema_dtypes(environment)
- self._agents = environment.possible_agents
-
- path_to_dataset = f"{base_dataset_dir}/{env_name}/{scenario_name}/{dataset_type}"
-
- file_path = Path(path_to_dataset)
- sub_dir_to_idx = {}
- idx = 0
- for subdir in os.listdir(file_path):
- if file_path.joinpath(subdir).is_dir():
- sub_dir_to_idx[subdir] = idx
- idx += 1
-
- def get_fname_idx(file_name: str) -> int:
- dir_idx = sub_dir_to_idx[file_name.split("/")[-2]] * 1000
- return dir_idx + int(file_name.split("log_")[-1].split(".")[0])
-
- filenames = [str(file_name) for file_name in file_path.glob("**/*.tfrecord")]
- filenames = sorted(filenames, key=get_fname_idx)
-
- filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
- self.raw_dataset = filename_dataset.flat_map(
- lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP").map(self._decode_fn)
- )
-
- self.period = DATASET_INFO[env_name][scenario_name]["period"]
- self.sequence_length = DATASET_INFO[env_name][scenario_name]["sequence_length"]
- self.max_episode_length = environment.max_episode_length
-
- def _decode_fn(self, record_bytes: Any) -> Dict[str, Any]:
- example = tf.io.parse_single_example(
- record_bytes,
- tree.map_structure(lambda x: tf.io.FixedLenFeature([], dtype=tf.string), self._schema),
- )
-
- for key, dtype in self._schema.items():
- example[key] = tf.io.parse_tensor(example[key], dtype)
-
- sample: Dict[str, dict] = {
- "observations": {},
- "actions": {},
- "rewards": {},
- "terminals": {},
- "truncations": {},
- "infos": {"legals": {}},
- }
- for agent in self._agents:
- sample["observations"][agent] = example[f"{agent}_observations"]
- sample["actions"][agent] = example[f"{agent}_actions"]
- sample["rewards"][agent] = example[f"{agent}_rewards"]
- sample["terminals"][agent] = 1 - example[f"{agent}_discounts"]
- sample["truncations"][agent] = tf.zeros_like(example[f"{agent}_discounts"])
- sample["infos"]["legals"][agent] = example[f"{agent}_legal_actions"]
-
- sample["infos"]["mask"] = example["zero_padding_mask"]
- sample["infos"]["state"] = example["env_state"]
- sample["infos"]["episode_return"] = example["episode_return"]
-
- return sample
-
- def __getattr__(self, name: Any) -> Any:
- """Expose any other attributes of the underlying environment.
-
- Args:
- ----
- name (str): attribute.
-
- Returns:
- -------
- Any: return attribute from env or underlying env.
-
- """
- if hasattr(self.__class__, name):
- return self.__getattribute__(name)
- else:
- return getattr(self._tf_dataset, name)
-
-
-def download_and_unzip_dataset(
- env_name: str,
- scenario_name: str,
- dataset_base_dir: str = "./datasets",
-) -> None:
- dataset_download_url = DATASET_INFO[env_name][scenario_name]["url"]
-
- # TODO add check to see if dataset exists already.
-
- os.makedirs(f"{dataset_base_dir}/tmp/", exist_ok=True)
- os.makedirs(f"{dataset_base_dir}/{env_name}/", exist_ok=True)
-
- zip_file_path = f"{dataset_base_dir}/tmp/tmp_dataset.zip"
-
- extraction_path = f"{dataset_base_dir}/{env_name}"
-
- response = requests.get(dataset_download_url, stream=True) # type: ignore
- total_length = response.headers.get("content-length")
-
- with open(zip_file_path, "wb") as file:
- if total_length is None: # no content length header
- file.write(response.content)
- else:
- dl = 0
- total_length = int(total_length) # type: ignore
- for data in response.iter_content(chunk_size=4096):
- dl += len(data)
- file.write(data)
- done = int(50 * dl / total_length) # type: ignore
- sys.stdout.write("\r[%s%s]" % ("=" * done, " " * (50 - done)))
- sys.stdout.flush()
-
- # Step 2: Unzip the file
- with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
- zip_ref.extractall(extraction_path)
-
- # Optionally, delete the zip file after extraction
- os.remove(zip_file_path)
-
-
def download_and_unzip_vault(
env_name: str,
scenario_name: str,
diff --git a/og_marl/replay_buffers.py b/og_marl/replay_buffers.py
index 328e1220..dcd64325 100644
--- a/og_marl/replay_buffers.py
+++ b/og_marl/replay_buffers.py
@@ -97,14 +97,13 @@ def populate_from_vault(
).read()
# Recreate the buffer and associated pure functions
- self._max_size = self._buffer_state.current_index
self._replay_buffer = fbx.make_trajectory_buffer(
add_batch_size=1,
sample_batch_size=self._batch_size,
sample_sequence_length=self._sequence_length,
period=1,
min_length_time_axis=1,
- max_size=self._max_size,
+ max_size=self._sequence_length,
)
self._buffer_sample_fn = jax.jit(self._replay_buffer.sample)
self._buffer_add_fn = jax.jit(self._replay_buffer.add)
diff --git a/og_marl/tf2/networks.py b/og_marl/tf2/networks.py
new file mode 100644
index 00000000..2f219ff1
--- /dev/null
+++ b/og_marl/tf2/networks.py
@@ -0,0 +1,142 @@
+from typing import Sequence
+
+import tensorflow as tf
+from tensorflow import Tensor
+import sonnet as snt
+
+
+class QMixer(snt.Module):
+
+ """QMIX mixing network."""
+
+ def __init__(
+ self,
+ num_agents: int,
+ embed_dim: int = 32,
+ hypernet_embed: int = 64,
+ non_monotonic: bool = False,
+ ):
+ """Initialise the QMIX mixing network.
+
+ Args:
+ ----
+ num_agents: Number of agents in the environment
+ embed_dim: The dimension of the output of the first layer
+ of the mixer.
+ hypernet_embed: Number of units in the hyper network
+ non_monotonic: If True, the mixing weights are not constrained to be
+ non-negative, dropping QMIX's monotonicity assumption.
+
+ """
+ super().__init__()
+ self.num_agents = num_agents
+ self.embed_dim = embed_dim
+ self.hypernet_embed = hypernet_embed
+ self._non_monotonic = non_monotonic
+
+ self.hyper_w_1 = snt.Sequential(
+ [
+ snt.Linear(self.hypernet_embed),
+ tf.nn.relu,
+ snt.Linear(self.embed_dim * self.num_agents),
+ ]
+ )
+
+ self.hyper_w_final = snt.Sequential(
+ [snt.Linear(self.hypernet_embed), tf.nn.relu, snt.Linear(self.embed_dim)]
+ )
+
+ # State dependent bias for hidden layer
+ self.hyper_b_1 = snt.Linear(self.embed_dim)
+
+ # V(s) instead of a bias for the last layers
+ self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)])
+
+ def __call__(self, agent_qs: Tensor, states: Tensor) -> Tensor:
+ """Forward method."""
+ B = agent_qs.shape[0] # batch size
+ state_dim = states.shape[2:]
+
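+ # Fold the time dimension into the batch: agent_qs -> (B*T, 1, N), states -> (B*T, *state_dim)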
+ agent_qs = tf.reshape(agent_qs, (-1, 1, self.num_agents))
+
+ states = tf.reshape(states, (-1, *state_dim))
+
+ # First layer
+ w1 = self.hyper_w_1(states)
+ if not self._non_monotonic:
+ w1 = tf.abs(w1)
+ b1 = self.hyper_b_1(states)
+ w1 = tf.reshape(w1, (-1, self.num_agents, self.embed_dim))
+ b1 = tf.reshape(b1, (-1, 1, self.embed_dim))
+ hidden = tf.nn.elu(tf.matmul(agent_qs, w1) + b1)
+
+ # Second layer
+ w_final = self.hyper_w_final(states)
+ if not self._non_monotonic:
+ w_final = tf.abs(w_final)
+ w_final = tf.reshape(w_final, (-1, self.embed_dim, 1))
+
+ # State-dependent bias
+ v = tf.reshape(self.V(states), (-1, 1, 1))
+
+ # Compute final output
+ y = tf.matmul(hidden, w_final) + v
+
+ # Reshape and return
+ q_tot = tf.reshape(y, (B, -1, 1))
+
+ return q_tot
+
+ def k(self, states: Tensor) -> Tensor:
+ """Return per-agent mixing weights for MAICQ, normalised to sum to one over agents."""
+ B, T = states.shape[:2]
+
+ w1 = tf.math.abs(self.hyper_w_1(states))
+ w_final = tf.math.abs(self.hyper_w_final(states))
+ w1 = tf.reshape(w1, shape=(-1, self.num_agents, self.embed_dim))
+ w_final = tf.reshape(w_final, shape=(-1, self.embed_dim, 1))
+ k = tf.matmul(w1, w_final)
+ k = tf.reshape(k, shape=(B, -1, self.num_agents))
+ k = k / (tf.reduce_sum(k, axis=2, keepdims=True) + 1e-10)
+ return k
+
+
+@snt.allow_empty_variables
+class IdentityNetwork(snt.Module):
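+ """Embedding network that returns its input unchanged (used as the default when none is supplied)."""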
+ def __init__(self) -> None:
+ super().__init__()
+ return
+
+ def __call__(self, x: Tensor) -> Tensor:
+ return x
+
+
+class CNNEmbeddingNetwork(snt.Module):
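+ """Convolutional embedding that flattens pixel observations into a vector for the recurrent systems."""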
+ def __init__(
+ self, output_channels: Sequence[int] = (8, 16), kernel_sizes: Sequence[int] = (3, 2)
+ ) -> None:
+ super().__init__()
+ assert len(output_channels) == len(kernel_sizes)
+
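+ # One Conv2D + ReLU block per entry in output_channels, followed by a single flatten layer.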
+ layers = []
+ for layer_i in range(len(output_channels)):
+ layers.append(snt.Conv2D(output_channels[layer_i], kernel_sizes[layer_i]))
+ layers.append(tf.nn.relu)
+ layers.append(tf.keras.layers.Flatten())
+
+ self.conv_net = snt.Sequential(layers)
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Embed a pixel-styled input into a vector using a conv net.
+
+ We assume the input has trailing dims
+ being the width, height and channel dimensions of the input.
+
+ The output shape is then given as (B,T,N,Embed)
+ """
+ leading_dims = x.shape[:-3]
+ trailing_dims = x.shape[-3:] # W,H,C
+
+ x = tf.reshape(x, shape=(-1, *trailing_dims))
+ embed = self.conv_net(x)
+ embed = tf.reshape(embed, shape=(*leading_dims, -1))
+ return embed
diff --git a/og_marl/tf2/systems/bc.py b/og_marl/tf2/systems/bc.py
index 5f8f047b..f0523b46 100644
--- a/og_marl/tf2/systems/bc.py
+++ b/og_marl/tf2/systems/bc.py
@@ -19,6 +19,7 @@
switch_two_leading_dims,
unroll_rnn,
)
+from og_marl.tf2.networks import IdentityNetwork
class DicreteActionBehaviourCloning(BaseMARLSystem):
@@ -33,6 +34,7 @@ def __init__(
discount: float = 0.99,
learning_rate: float = 1e-3,
add_agent_id_to_obs: bool = True,
+ observation_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment, logger, discount=discount, add_agent_id_to_obs=add_agent_id_to_obs
@@ -48,6 +50,9 @@ def __init__(
snt.Linear(self._environment._num_actions),
]
) # shared network for all agents
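+ # Optional observation embedding (e.g. a CNN for pixel inputs) applied before the policy network; identity by default.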
+ if observation_embedding_network is None:
+ observation_embedding_network = IdentityNetwork()
+ self._policy_embedding_network = observation_embedding_network
self._optimizer = snt.optimizers.RMSProp(learning_rate=learning_rate)
@@ -101,14 +106,13 @@ def _tf_select_actions(
agent_observation, i, len(self._environment.possible_agents)
)
agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension
- logits, next_rnn_states[agent] = self._policy_network(
- agent_observation, rnn_states[agent]
- )
+ embedding = self._policy_embedding_network(agent_observation)
+ logits, next_rnn_states[agent] = self._policy_network(embedding, rnn_states[agent])
probs = tf.nn.softmax(logits)
if legal_actions is not None:
- agent_legals = tf.expand_dims(legal_actions[agent], axis=0)
+ agent_legals = tf.cast(tf.expand_dims(legal_actions[agent], axis=0), "float32")
probs = (probs * agent_legals) / tf.reduce_sum(
probs * agent_legals
) # mask and renorm
@@ -116,7 +120,7 @@ def _tf_select_actions(
action = tfp.distributions.Categorical(probs=probs).sample(1)
# Store agent action
- actions[agent] = action
+ actions[agent] = action[0]
return actions, next_rnn_states
@@ -147,11 +151,16 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
resets = switch_two_leading_dims(resets)
actions = switch_two_leading_dims(actions)
+ # Merge batch_dim and agent_dim
+ observations = merge_batch_and_agent_dim_of_time_major_sequence(observations)
+ resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
+
with tf.GradientTape() as tape:
+ embeddings = self._policy_embedding_network(observations)
probs_out = unroll_rnn(
self._policy_network,
- merge_batch_and_agent_dim_of_time_major_sequence(observations),
- merge_batch_and_agent_dim_of_time_major_sequence(resets),
+ embeddings,
+ resets,
)
probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N)
@@ -163,7 +172,10 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
bc_loss = tf.reduce_mean(bc_loss)
# Apply gradients to policy
- variables = (*self._policy_network.trainable_variables,) # Get trainable variables
+ variables = (
+ *self._policy_network.trainable_variables,
+ *self._policy_embedding_network.trainable_variables,
+ ) # Get trainable variables
gradients = tape.gradient(bc_loss, variables) # Compute gradients.
self._optimizer.apply(gradients, variables)
diff --git a/og_marl/tf2/systems/idrqn.py b/og_marl/tf2/systems/idrqn.py
index 4826421d..55ba3448 100644
--- a/og_marl/tf2/systems/idrqn.py
+++ b/og_marl/tf2/systems/idrqn.py
@@ -14,7 +14,7 @@
"""Implementation of independent Q-learning (DRQN style)"""
import copy
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Sequence, Tuple, Optional
import numpy as np
import sonnet as snt
@@ -37,6 +37,7 @@
switch_two_leading_dims,
unroll_rnn,
)
+from og_marl.tf2.networks import IdentityNetwork
class IDRQNSystem(BaseMARLSystem):
@@ -55,6 +56,7 @@ def __init__(
eps_min: float = 0.05,
eps_decay_timesteps: int = 50_000,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment, logger, add_agent_id_to_obs=add_agent_id_to_obs, discount=discount
@@ -76,6 +78,12 @@ def __init__(
]
) # shared network for all agents
+ # Embedding network
+ if observation_embedding_network is None:
+ observation_embedding_network = IdentityNetwork()
+ self._q_embedding_network = observation_embedding_network
+ self._target_q_embedding_network = copy.deepcopy(observation_embedding_network)
+
# Target Q-network
self._target_q_network = copy.deepcopy(self._q_network)
self._target_update_period = target_update_period
@@ -101,7 +109,7 @@ def reset(self) -> None:
def select_actions(
self,
observations: Dict[str, np.ndarray],
- legal_actions: Optional[Dict[str, np.ndarray]] = None,
+ legal_actions: Dict[str, np.ndarray],
explore: bool = True,
) -> Dict[str, np.ndarray]:
if explore:
@@ -136,7 +144,8 @@ def _tf_select_actions(
agent_observation, i, len(self._environment.possible_agents)
)
agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension
- q_values, next_rnn_states[agent] = self._q_network(agent_observation, rnn_states[agent])
+ embedding = self._q_embedding_network(agent_observation)
+ q_values, next_rnn_states[agent] = self._q_network(embedding, rnn_states[agent])
agent_legal_actions = legal_actions[agent]
masked_q_values = tf.where(
@@ -149,7 +158,9 @@ def _tf_select_actions(
epsilon = tf.maximum(1.0 - self._eps_dec * env_step_ctr, self._eps_min)
greedy_probs = tf.one_hot(greedy_action, masked_q_values.shape[-1])
- explore_probs = agent_legal_actions / tf.reduce_sum(agent_legal_actions)
+ explore_probs = tf.cast(
+ agent_legal_actions / tf.reduce_sum(agent_legal_actions), "float32"
+ )
probs = (1.0 - epsilon) * greedy_probs + epsilon * explore_probs
probs = tf.expand_dims(probs, axis=0)
@@ -197,7 +208,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -207,7 +219,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
with tf.GradientTape() as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -241,7 +254,10 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
loss = tf.reduce_mean(loss)
# Get trainable variables
- variables = (*self._q_network.trainable_variables,)
+ variables = (
+ *self._q_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
+ )
# Compute gradients.
gradients = tape.gradient(loss, variables)
@@ -250,10 +266,13 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
self._optimizer.apply(gradients, variables)
# Online variables
- online_variables = (*self._q_network.variables,)
+ online_variables = (*self._q_network.variables, *self._q_embedding_network.variables)
# Get target variables
- target_variables = (*self._target_q_network.variables,)
+ target_variables = (
+ *self._target_q_network.variables,
+ *self._target_q_embedding_network.variables,
+ )
# Maybe update target network
self._update_target_network(train_step_ctr, online_variables, target_variables)
diff --git a/og_marl/tf2/systems/idrqn_bcq.py b/og_marl/tf2/systems/idrqn_bcq.py
index 176a2ef4..a0cea06f 100644
--- a/og_marl/tf2/systems/idrqn_bcq.py
+++ b/og_marl/tf2/systems/idrqn_bcq.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Implementation of QMIX+BCQ"""
-from typing import Any, Dict
+from typing import Any, Dict, Optional
import sonnet as snt
import tensorflow as tf
@@ -30,6 +30,7 @@
switch_two_leading_dims,
unroll_rnn,
)
+from og_marl.tf2.networks import IdentityNetwork
class IDRQNBCQSystem(IDRQNSystem):
@@ -47,6 +48,7 @@ def __init__(
target_update_period: int = 200,
learning_rate: float = 3e-4,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment,
@@ -57,6 +59,7 @@ def __init__(
discount=discount,
target_update_period=target_update_period,
learning_rate=learning_rate,
+ observation_embedding_network=observation_embedding_network,
)
self._threshold = bc_threshold
@@ -71,6 +74,10 @@ def __init__(
]
)
+ if observation_embedding_network is None:
+ observation_embedding_network = IdentityNetwork()
+ self._bc_embedding_network = observation_embedding_network
+
@tf.function(jit_compile=True)
def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[str, Numeric]:
# Unpack the batch
@@ -100,7 +107,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -110,7 +118,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
with tf.GradientTape() as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ q_embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, q_embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -126,7 +135,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
###################
# Unroll behaviour cloning network
- probs_out = unroll_rnn(self._behaviour_cloning_network, observations, resets)
+ bc_embeddings = self._bc_embedding_network(observations)
+ probs_out = unroll_rnn(self._behaviour_cloning_network, bc_embeddings, resets)
# Expand batch and agent_dim
probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N)
@@ -175,7 +185,9 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
# Get trainable variables
variables = (
*self._q_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
*self._behaviour_cloning_network.trainable_variables,
+ *self._bc_embedding_network.trainable_variables,
)
# Compute gradients.
@@ -185,10 +197,13 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
self._optimizer.apply(gradients, variables)
# Online variables
- online_variables = (*self._q_network.variables,)
+ online_variables = (*self._q_network.variables, *self._q_embedding_network.variables)
# Get target variables
- target_variables = (*self._target_q_network.variables,)
+ target_variables = (
+ *self._target_q_network.variables,
+ *self._target_q_embedding_network.variables,
+ )
# Maybe update target network
self._update_target_network(train_step, online_variables, target_variables)
diff --git a/og_marl/tf2/systems/idrqn_cql.py b/og_marl/tf2/systems/idrqn_cql.py
index bd9d3d61..1c5d1dfb 100644
--- a/og_marl/tf2/systems/idrqn_cql.py
+++ b/og_marl/tf2/systems/idrqn_cql.py
@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of IDRQN+CQL"""
-from typing import Any, Dict
+from typing import Any, Dict, Optional
import tensorflow as tf
+import sonnet as snt
from chex import Numeric
from og_marl.environments.base import BaseEnvironment
from og_marl.loggers import BaseLogger
-from og_marl.tf2.systems.qmix import QMIXSystem
+from og_marl.tf2.systems.idrqn import IDRQNSystem
from og_marl.tf2.utils import (
batch_concat_agent_id_to_obs,
expand_batch_and_agent_dim_of_time_major_sequence,
@@ -30,7 +31,7 @@
)
-class IDRQNCQLSystem(QMIXSystem):
+class IDRQNCQLSystem(IDRQNSystem):
"""IDRQN+CQL System"""
@@ -42,24 +43,22 @@ def __init__(
cql_weight: float = 1.0,
linear_layer_dim: int = 64,
recurrent_layer_dim: int = 64,
- mixer_embed_dim: int = 32,
- mixer_hyper_dim: int = 64,
discount: float = 0.99,
target_update_period: int = 200,
learning_rate: float = 3e-4,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment,
logger,
linear_layer_dim=linear_layer_dim,
recurrent_layer_dim=recurrent_layer_dim,
- mixer_embed_dim=mixer_embed_dim,
- mixer_hyper_dim=mixer_hyper_dim,
add_agent_id_to_obs=add_agent_id_to_obs,
discount=discount,
target_update_period=target_update_period,
learning_rate=learning_rate,
+ observation_embedding_network=observation_embedding_network,
)
# CQL
@@ -95,7 +94,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -105,7 +105,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
with tf.GradientTape() as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -166,7 +167,10 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
loss = td_loss + cql_loss
# Get trainable variables
- variables = (*self._q_network.trainable_variables,)
+ variables = (
+ *self._q_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
+ )
# Compute gradients.
gradients = tape.gradient(loss, variables)
@@ -175,10 +179,13 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
self._optimizer.apply(gradients, variables)
# Online variables
- online_variables = (*self._q_network.variables,)
+ online_variables = (*self._q_network.variables, *self._q_embedding_network.variables)
# Get target variables
- target_variables = (*self._target_q_network.variables,)
+ target_variables = (
+ *self._target_q_network.variables,
+ *self._target_q_embedding_network.variables,
+ )
# Maybe update target network
self._update_target_network(train_step, online_variables, target_variables)
diff --git a/og_marl/tf2/systems/maicq.py b/og_marl/tf2/systems/maicq.py
index 4ac55c3a..de4b4051 100644
--- a/og_marl/tf2/systems/maicq.py
+++ b/og_marl/tf2/systems/maicq.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Implementation of MAICQ"""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple, Optional
import numpy as np
import sonnet as snt
@@ -34,6 +34,7 @@
switch_two_leading_dims,
unroll_rnn,
)
+from og_marl.tf2.networks import IdentityNetwork
class MAICQSystem(QMIXSystem):
@@ -54,6 +55,8 @@ def __init__(
target_update_period: int = 200,
learning_rate: float = 3e-4,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
+ state_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment,
@@ -66,12 +69,19 @@ def __init__(
learning_rate=learning_rate,
mixer_embed_dim=mixer_embed_dim,
mixer_hyper_dim=mixer_hyper_dim,
+ observation_embedding_network=observation_embedding_network,
+ state_embedding_network=state_embedding_network,
)
- # ICQ
+ # ICQ hyper-params
self._icq_advantages_beta = icq_advantages_beta
self._icq_target_q_taken_beta = icq_target_q_taken_beta
+ # Embedding Networks
+ if observation_embedding_network is None:
+ observation_embedding_network = IdentityNetwork()
+ self._policy_embedding_network = observation_embedding_network
+
# Policy Network
self._policy_network = snt.DeepRNN(
[
@@ -96,7 +106,7 @@ def reset(self) -> None:
def select_actions(
self,
observations: Dict[str, np.ndarray],
- legal_actions: Optional[Dict[str, np.ndarray]] = None,
+ legal_actions: Dict[str, np.ndarray],
explore: bool = False,
) -> Dict[str, np.ndarray]:
observations = tree.map_structure(tf.convert_to_tensor, observations)
@@ -123,9 +133,8 @@ def _tf_select_actions(
agent_observation, i, len(self._environment.possible_agents)
)
agent_observation = tf.expand_dims(agent_observation, axis=0) # add batch dimension
- probs, next_rnn_states[agent] = self._policy_network(
- agent_observation, rnn_states[agent]
- )
+ embedding = self._policy_embedding_network(agent_observation)
+ probs, next_rnn_states[agent] = self._policy_network(embedding, rnn_states[agent])
agent_legal_actions = legal_actions[agent]
masked_probs = tf.where(
@@ -173,7 +182,8 @@ def _tf_train_step(
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -183,7 +193,8 @@ def _tf_train_step(
with tf.GradientTape(persistent=True) as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ q_embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, q_embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -192,7 +203,8 @@ def _tf_train_step(
q_vals = switch_two_leading_dims(qs_out)
# Unroll the policy
- probs_out = unroll_rnn(self._policy_network, observations, resets)
+ policy_embeddings = self._policy_embedding_network(observations)
+ probs_out = unroll_rnn(self._policy_network, policy_embeddings, resets)
# Expand batch and agent_dim
probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N)
@@ -216,7 +228,9 @@ def _tf_train_step(
pi_taken = gather(probs_out, actions, keepdims=False)
log_pi_taken = tf.math.log(pi_taken)
- coe = self._mixer.k(env_states)
+ env_state_embeddings = self._state_embedding_network(env_states)
+ target_env_state_embeddings = self._target_state_embedding_network(env_states)
+ coe = self._mixer.k(env_state_embeddings)
coma_loss = -tf.reduce_mean(coe * (len(advantages) * advantages * log_pi_taken))
@@ -225,8 +239,8 @@ def _tf_train_step(
target_q_taken = gather(target_q_vals, actions, axis=-1)
# Mixing critics
- q_taken = self._mixer(q_taken, env_states)
- target_q_taken = self._target_mixer(target_q_taken, env_states)
+ q_taken = self._mixer(q_taken, env_state_embeddings)
+ target_q_taken = self._target_mixer(target_q_taken, target_env_state_embeddings)
advantage_Q = tf.nn.softmax(target_q_taken / self._icq_target_q_taken_beta, axis=0)
target_q_taken = len(advantage_Q) * advantage_Q * target_q_taken
@@ -252,6 +266,8 @@ def _tf_train_step(
*self._policy_network.trainable_variables,
*self._q_network.trainable_variables,
*self._mixer.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
+ *self._policy_embedding_network.trainable_variables,
) # Get trainable variables
gradients = tape.gradient(loss, variables) # Compute gradients.
@@ -262,12 +278,16 @@ def _tf_train_step(
online_variables = (
*self._q_network.variables,
*self._mixer.variables,
+ *self._q_embedding_network.variables,
+ *self._state_embedding_network.variables,
)
# Get target variables
target_variables = (
*self._target_q_network.variables,
*self._target_mixer.variables,
+ *self._target_q_embedding_network.variables,
+ *self._target_state_embedding_network.variables,
)
# Maybe update target network
diff --git a/og_marl/tf2/systems/qmix.py b/og_marl/tf2/systems/qmix.py
index 2ac59a6e..04ac907d 100644
--- a/og_marl/tf2/systems/qmix.py
+++ b/og_marl/tf2/systems/qmix.py
@@ -13,8 +13,9 @@
# limitations under the License.
"""Implementation of QMIX"""
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Tuple, Optional
+import copy
import sonnet as snt
import tensorflow as tf
from chex import Numeric
@@ -31,6 +32,7 @@
switch_two_leading_dims,
unroll_rnn,
)
+from og_marl.tf2.networks import IdentityNetwork, QMixer
class QMIXSystem(IDRQNSystem):
@@ -50,6 +52,8 @@ def __init__(
learning_rate: float = 3e-4,
eps_decay_timesteps: int = 50_000,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
+ state_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment,
@@ -61,8 +65,14 @@ def __init__(
target_update_period=target_update_period,
learning_rate=learning_rate,
eps_decay_timesteps=eps_decay_timesteps,
+ observation_embedding_network=observation_embedding_network,
)
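+ # Optional embedding of the global state fed to the mixer; identity by default.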
+ if state_embedding_network is None:
+ state_embedding_network = IdentityNetwork()
+ self._state_embedding_network = state_embedding_network
+ self._target_state_embedding_network = copy.deepcopy(state_embedding_network)
+
self._mixer = QMixer(
len(self._environment.possible_agents), mixer_embed_dim, mixer_hyper_dim
)
@@ -102,7 +112,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -112,7 +123,8 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
with tf.GradientTape() as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ q_network_embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, q_network_embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -131,8 +143,16 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
target_max_qs = gather(target_qs_out, cur_max_actions, axis=-1, keepdims=False)
# Q-MIXING
+ env_state_embeddings, target_env_state_embeddings = (
+ self._state_embedding_network(env_states),
+ self._target_state_embedding_network(env_states),
+ )
chosen_action_qs, target_max_qs, rewards = self._mixing(
- chosen_action_qs, target_max_qs, env_states, rewards
+ chosen_action_qs,
+ target_max_qs,
+ env_state_embeddings,
+ target_env_state_embeddings,
+ rewards,
)
# Reduce Agent Dim
@@ -150,7 +170,12 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
loss = tf.reduce_mean(loss)
# Get trainable variables
- variables = (*self._q_network.trainable_variables, *self._mixer.trainable_variables)
+ variables = (
+ *self._q_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
+ *self._mixer.trainable_variables,
+ *self._state_embedding_network.trainable_variables,
+ )
# Compute gradients.
gradients = tape.gradient(loss, variables)
@@ -163,13 +188,17 @@ def _tf_train_step(self, train_step_ctr: int, experience: Dict[str, Any]) -> Dic
# Online variables
online_variables = (
*self._q_network.variables,
+ *self._q_embedding_network.variables,
*self._mixer.variables,
+ *self._state_embedding_network.variables,
)
# Get target variables
target_variables = (
*self._target_q_network.variables,
+ *self._target_q_embedding_network.variables,
*self._target_mixer.variables,
+ *self._target_state_embedding_network.variables,
)
# Maybe update target network
@@ -185,7 +214,8 @@ def _mixing(
self,
chosen_action_qs: Tensor,
target_max_qs: Tensor,
- states: Tensor,
+ state_embeddings: Tensor,
+ target_state_embeddings: Tensor,
rewards: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
"""QMIX"""
@@ -194,104 +224,7 @@ def _mixing(
# target_max_qs = tf.reduce_sum(target_max_qs, axis=2, keepdims=True)
# VDN
- chosen_action_qs = self._mixer(chosen_action_qs, states)
- target_max_qs = self._target_mixer(target_max_qs, states)
+ chosen_action_qs = self._mixer(chosen_action_qs, state_embeddings)
+ target_max_qs = self._target_mixer(target_max_qs, target_state_embeddings)
rewards = tf.reduce_mean(rewards, axis=2, keepdims=True)
return chosen_action_qs, target_max_qs, rewards
-
-
-class QMixer(snt.Module):
-
- """QMIX mixing network."""
-
- def __init__(
- self,
- num_agents: int,
- embed_dim: int = 32,
- hypernet_embed: int = 64,
- preprocess_network: snt.Module = None,
- non_monotonic: bool = False,
- ):
- """Initialise QMIX mixing network
-
- Args:
- ----
- num_agents: Number of agents in the environment
- state_dim: Dimensions of the global environment state
- embed_dim: The dimension of the output of the first layer
- of the mixer.
- hypernet_embed: Number of units in the hyper network
-
- """
- super().__init__()
- self.num_agents = num_agents
- self.embed_dim = embed_dim
- self.hypernet_embed = hypernet_embed
- self._non_monotonic = non_monotonic
-
- self.hyper_w_1 = snt.Sequential(
- [
- snt.Linear(self.hypernet_embed),
- tf.nn.relu,
- snt.Linear(self.embed_dim * self.num_agents),
- ]
- )
-
- self.hyper_w_final = snt.Sequential(
- [snt.Linear(self.hypernet_embed), tf.nn.relu, snt.Linear(self.embed_dim)]
- )
-
- # State dependent bias for hidden layer
- self.hyper_b_1 = snt.Linear(self.embed_dim)
-
- # V(s) instead of a bias for the last layers
- self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)])
-
- def __call__(self, agent_qs: Tensor, states: Tensor) -> Tensor:
- """Forward method."""
- B = agent_qs.shape[0] # batch size
- state_dim = states.shape[2:]
-
- agent_qs = tf.reshape(agent_qs, (-1, 1, self.num_agents))
-
- # states = tf.ones_like(states)
- states = tf.reshape(states, (-1, *state_dim))
-
- # First layer
- w1 = self.hyper_w_1(states)
- if not self._non_monotonic:
- w1 = tf.abs(w1)
- b1 = self.hyper_b_1(states)
- w1 = tf.reshape(w1, (-1, self.num_agents, self.embed_dim))
- b1 = tf.reshape(b1, (-1, 1, self.embed_dim))
- hidden = tf.nn.elu(tf.matmul(agent_qs, w1) + b1)
-
- # Second layer
- w_final = self.hyper_w_final(states)
- if not self._non_monotonic:
- w_final = tf.abs(w_final)
- w_final = tf.reshape(w_final, (-1, self.embed_dim, 1))
-
- # State-dependent bias
- v = tf.reshape(self.V(states), (-1, 1, 1))
-
- # Compute final output
- y = tf.matmul(hidden, w_final) + v
-
- # Reshape and return
- q_tot = tf.reshape(y, (B, -1, 1))
-
- return q_tot
-
- def k(self, states: Tensor) -> Tensor:
- """Method used by MAICQ."""
- B, T = states.shape[:2]
-
- w1 = tf.math.abs(self.hyper_w_1(states))
- w_final = tf.math.abs(self.hyper_w_final(states))
- w1 = tf.reshape(w1, shape=(-1, self.num_agents, self.embed_dim))
- w_final = tf.reshape(w_final, shape=(-1, self.embed_dim, 1))
- k = tf.matmul(w1, w_final)
- k = tf.reshape(k, shape=(B, -1, self.num_agents))
- k = k / (tf.reduce_sum(k, axis=2, keepdims=True) + 1e-10)
- return k
diff --git a/og_marl/tf2/systems/qmix_bcq.py b/og_marl/tf2/systems/qmix_bcq.py
index 715995e7..70ecad6c 100644
--- a/og_marl/tf2/systems/qmix_bcq.py
+++ b/og_marl/tf2/systems/qmix_bcq.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Implementation of QMIX+BCQ"""
-from typing import Any, Dict
+from typing import Any, Dict, Optional
import sonnet as snt
import tensorflow as tf
@@ -30,6 +30,7 @@
switch_two_leading_dims,
unroll_rnn,
)
+from og_marl.tf2.networks import IdentityNetwork
class QMIXBCQSystem(QMIXSystem):
@@ -49,6 +50,8 @@ def __init__(
target_update_period: int = 200,
learning_rate: float = 3e-4,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
+ state_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment,
@@ -61,6 +64,8 @@ def __init__(
discount=discount,
target_update_period=target_update_period,
learning_rate=learning_rate,
+ observation_embedding_network=observation_embedding_network,
+ state_embedding_network=state_embedding_network,
)
self._threshold = bc_threshold
@@ -75,6 +80,10 @@ def __init__(
]
)
+ if observation_embedding_network is None:
+ observation_embedding_network = IdentityNetwork()
+ self._bc_embedding_network = observation_embedding_network
+
@tf.function(jit_compile=True)
def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[str, Numeric]:
# Unpack the batch
@@ -105,7 +114,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -115,7 +125,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
with tf.GradientTape() as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ q_embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, q_embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -131,7 +142,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
###################
# Unroll behaviour cloning network
- probs_out = unroll_rnn(self._behaviour_cloning_network, observations, resets)
+ bc_embeddings = self._bc_embedding_network(observations)
+ probs_out = unroll_rnn(self._behaviour_cloning_network, bc_embeddings, resets)
# Expand batch and agent_dim
probs_out = expand_batch_and_agent_dim_of_time_major_sequence(probs_out, B, N)
@@ -162,9 +174,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
####### END #######
###################
- # Maybe do mixing (e.g. QMIX) but not in independent system
+ # Q-mixing: mix per-agent Qs into q_tot using the embedded global state
+ env_state_embeddings, target_env_state_embeddings = (
+ self._state_embedding_network(env_states),
+ self._target_state_embedding_network(env_states),
+ )
chosen_action_qs, target_max_qs, rewards = self._mixing(
- chosen_action_qs, target_max_qs, env_states, rewards
+ chosen_action_qs,
+ target_max_qs,
+ env_state_embeddings,
+ target_env_state_embeddings,
+ rewards,
)
# Compute targets
@@ -185,7 +205,9 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
# Get trainable variables
variables = (
*self._q_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
*self._mixer.trainable_variables,
+ *self._state_embedding_network.trainable_variables,
*self._behaviour_cloning_network.trainable_variables,
)
@@ -198,13 +220,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
# Online variables
online_variables = (
*self._q_network.variables,
+ *self._q_embedding_network.variables,
*self._mixer.variables,
+ *self._state_embedding_network.variables,
)
# Get target variables
target_variables = (
*self._target_q_network.variables,
+ *self._target_q_embedding_network.variables,
*self._target_mixer.variables,
+ *self._target_state_embedding_network.variables,
)
# Maybe update target network
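Both QMIX+BCQ above and QMIX+CQL below now accept optional observation_embedding_network and state_embedding_network modules and default to an identity mapping when left as None (the behaviour-cloning branch above shows the IdentityNetwork fallback explicitly), so existing vector-observation configs keep their previous behaviour. A minimal sketch of a custom state embedding that could be passed in; the module name and layer sizes are illustrative, not part of the repo:

import sonnet as snt
import tensorflow as tf

class MLPStateEmbedding(snt.Module):
    """Illustrative embedding: projects the global state before it reaches the mixer."""

    def __init__(self, output_dim: int = 64):
        super().__init__()
        self._output_dim = output_dim
        self._net = snt.Sequential([snt.Linear(128), tf.nn.relu, snt.Linear(output_dim)])

    def __call__(self, states: tf.Tensor) -> tf.Tensor:
        # States arrive as [B, T, state_dim]; flatten to rank 2 for the Linear
        # layers, then restore the leading batch/time dimensions.
        leading_shape = tf.shape(states)[:-1]
        flat = tf.reshape(states, (-1, states.shape[-1]))
        embedded = self._net(flat)
        return tf.reshape(embedded, tf.concat([leading_shape, [self._output_dim]], axis=0))

Passing state_embedding_network=MLPStateEmbedding() (and, for pixel observations, a CNN module as observation_embedding_network) at system construction routes env_states and observations through the embedding before the RNN unroll and the mixer, exactly as the train-step changes above show.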
diff --git a/og_marl/tf2/systems/qmix_cql.py b/og_marl/tf2/systems/qmix_cql.py
index 4fe7abf3..9946252f 100644
--- a/og_marl/tf2/systems/qmix_cql.py
+++ b/og_marl/tf2/systems/qmix_cql.py
@@ -13,9 +13,10 @@
# limitations under the License.
"""Implementation of QMIX+CQL"""
-from typing import Any, Dict
+from typing import Any, Dict, Optional
import tensorflow as tf
+import sonnet as snt
from chex import Numeric
from og_marl.environments.base import BaseEnvironment
@@ -52,6 +53,8 @@ def __init__(
target_update_period: int = 200,
learning_rate: float = 3e-4,
add_agent_id_to_obs: bool = False,
+ observation_embedding_network: Optional[snt.Module] = None,
+ state_embedding_network: Optional[snt.Module] = None,
):
super().__init__(
environment,
@@ -64,6 +67,8 @@ def __init__(
discount=discount,
target_update_period=target_update_period,
learning_rate=learning_rate,
+ observation_embedding_network=observation_embedding_network,
+ state_embedding_network=state_embedding_network,
)
# CQL
@@ -100,7 +105,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
resets = merge_batch_and_agent_dim_of_time_major_sequence(resets)
# Unroll target network
- target_qs_out = unroll_rnn(self._target_q_network, observations, resets)
+ target_embeddings = self._target_q_embedding_network(observations)
+ target_qs_out = unroll_rnn(self._target_q_network, target_embeddings, resets)
# Expand batch and agent_dim
target_qs_out = expand_batch_and_agent_dim_of_time_major_sequence(target_qs_out, B, N)
@@ -110,7 +116,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
with tf.GradientTape() as tape:
# Unroll online network
- qs_out = unroll_rnn(self._q_network, observations, resets)
+ q_embeddings = self._q_embedding_network(observations)
+ qs_out = unroll_rnn(self._q_network, q_embeddings, resets)
# Expand batch and agent_dim
qs_out = expand_batch_and_agent_dim_of_time_major_sequence(qs_out, B, N)
@@ -128,9 +135,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
cur_max_actions = tf.argmax(qs_out_selector, axis=3)
target_max_qs = gather(target_qs_out, cur_max_actions, axis=-1)
- # Maybe do mixing (e.g. QMIX) but not in independent system
+ # Q-mixing: mix per-agent Qs into q_tot using the embedded global state
+ env_state_embeddings, target_env_state_embeddings = (
+ self._state_embedding_network(env_states),
+ self._target_state_embedding_network(env_states),
+ )
chosen_action_qs, target_max_qs, rewards = self._mixing(
- chosen_action_qs, target_max_qs, env_states, rewards
+ chosen_action_qs,
+ target_max_qs,
+ env_state_embeddings,
+ target_env_state_embeddings,
+ rewards,
)
# Compute targets
@@ -159,7 +174,7 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
) # [B, T, N]
# Mixing
- mixed_ood_qs = self._mixer(ood_qs, env_states) # [B, T, 1]
+ mixed_ood_qs = self._mixer(ood_qs, env_state_embeddings) # [B, T, 1]
all_mixed_ood_qs.append(mixed_ood_qs) # [B, T, Ra]
all_mixed_ood_qs.append(chosen_action_qs) # [B, T, Ra + 1]
@@ -177,7 +192,12 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
loss = td_loss + cql_loss
# Get trainable variables
- variables = (*self._q_network.trainable_variables, *self._mixer.trainable_variables)
+ variables = (
+ *self._q_network.trainable_variables,
+ *self._q_embedding_network.trainable_variables,
+ *self._mixer.trainable_variables,
+ *self._state_embedding_network.trainable_variables,
+ )
# Compute gradients.
gradients = tape.gradient(loss, variables)
@@ -188,13 +208,17 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
# Online variables
online_variables = (
*self._q_network.variables,
+ *self._q_embedding_network.variables,
*self._mixer.variables,
+ *self._state_embedding_network.variables,
)
# Get target variables
target_variables = (
*self._target_q_network.variables,
+ *self._target_q_embedding_network.variables,
*self._target_mixer.variables,
+ *self._target_state_embedding_network.variables,
)
# Maybe update target network
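The cql_loss combined with td_loss above is the conservative regulariser that motivates mixing the out-of-distribution Q-values with env_state_embeddings rather than the raw env_states. A hedged sketch of the standard form of that regulariser; illustrative only, the function name is made up and the repo's exact weighting may differ:

import tensorflow as tf

def cql_regulariser(all_mixed_ood_qs, chosen_action_qs):
    # all_mixed_ood_qs: list of [B, T, 1] mixed Q-values for sampled OOD joint
    # actions, with the dataset Q-values appended last, as built in the train step.
    stacked = tf.concat(all_mixed_ood_qs, axis=-1)                    # [B, T, Ra + 1]
    ood_term = tf.reduce_logsumexp(stacked, axis=-1, keepdims=True)   # [B, T, 1]
    # Push down Q-values of unseen joint actions while keeping dataset actions up.
    return tf.reduce_mean(ood_term - chosen_action_qs)

The total objective remains loss = td_loss + cql_loss, as in the hunk above.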
diff --git a/setup.py b/setup.py
index dc5e1cc2..73ebd824 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,8 @@
"gymnasium",
"requests",
"jax[cpu]==0.4.20",
+ "matplotlib",
+ "seaborn",
# "flashbax==0.1.0", # install post
],
extras_require={