From 1a17b267df850f8a649869466826b9e02ee5e5b7 Mon Sep 17 00:00:00 2001
From: Claude Formanek
Date: Fri, 8 Mar 2024 08:36:36 +0200
Subject: [PATCH] Added wrappers for all three environments.

---
 examples/tf2/online/idrqn_smax.py     |  28 +++++++
 examples/tf2/run_all_baselines.py     |   2 +-
 og_marl/environments/__init__.py      |  59 ++++++++++++++
 og_marl/environments/jaxmarl_smax.py  |  98 +++++++++++++++++++++++
 og_marl/environments/jumanji_lbf.py   | 111 ++++++++++++++++++++++++++
 og_marl/environments/jumanji_rware.py | 110 +++++++++++++++++++++++++
 og_marl/environments/utils.py         |  47 -----------
 og_marl/offline_dataset.py            |   4 +-
 og_marl/tf2/utils.py                  |  19 +++--
 9 files changed, 423 insertions(+), 55 deletions(-)
 create mode 100644 examples/tf2/online/idrqn_smax.py
 create mode 100644 og_marl/environments/jaxmarl_smax.py
 create mode 100644 og_marl/environments/jumanji_lbf.py
 create mode 100644 og_marl/environments/jumanji_rware.py
 delete mode 100644 og_marl/environments/utils.py

diff --git a/examples/tf2/online/idrqn_smax.py b/examples/tf2/online/idrqn_smax.py
new file mode 100644
index 00000000..2b2c9a47
--- /dev/null
+++ b/examples/tf2/online/idrqn_smax.py
@@ -0,0 +1,28 @@
+# Copyright 2023 InstaDeep Ltd. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from og_marl.environments.jaxmarl_smax import SMAX
+from og_marl.loggers import WandbLogger
+from og_marl.replay_buffers import FlashbaxReplayBuffer
+from og_marl.tf2.systems.idrqn import IDRQNSystem
+
+env = SMAX("3m")
+
+logger = WandbLogger(entity="claude_formanek")
+
+system = IDRQNSystem(env, logger, eps_decay_timesteps=50_000)
+
+replay_buffer = FlashbaxReplayBuffer(sequence_length=20)
+
+system.train_online(replay_buffer)
diff --git a/examples/tf2/run_all_baselines.py b/examples/tf2/run_all_baselines.py
index b8cf2a55..51caece7 100644
--- a/examples/tf2/run_all_baselines.py
+++ b/examples/tf2/run_all_baselines.py
@@ -1,6 +1,6 @@
 import os
 
-from og_marl.environments.utils import get_environment
+from og_marl.environments import get_environment
 from og_marl.loggers import JsonWriter, WandbLogger
 from og_marl.replay_buffers import FlashbaxReplayBuffer
 from og_marl.tf2.systems import get_system
diff --git a/og_marl/environments/__init__.py b/og_marl/environments/__init__.py
index e69de29b..c1439ff2 100644
--- a/og_marl/environments/__init__.py
+++ b/og_marl/environments/__init__.py
@@ -0,0 +1,59 @@
+# type: ignore
+
+# Copyright 2023 InstaDeep Ltd. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from og_marl.environments.base import BaseEnvironment
+
+
+def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
+    if env_name == "smac_v1":
+        from og_marl.environments.smacv1 import SMACv1
+
+        return SMACv1(scenario)
+    elif env_name == "smac_v2":
+        from og_marl.environments.smacv2 import SMACv2
+
+        return SMACv2(scenario)
+    elif env_name == "mamujoco":
+        from og_marl.environments.old_mamujoco import MAMuJoCo
+
+        return MAMuJoCo(scenario)
+    elif env_name == "gymnasium_mamujoco":
+        from og_marl.environments.gymnasium_mamujoco import MAMuJoCo
+
+        return MAMuJoCo(scenario)
+    elif env_name == "flatland":
+        from og_marl.environments.flatland_wrapper import Flatland
+
+        return Flatland(scenario)
+    elif env_name == "voltage_control":
+        from og_marl.environments.voltage_control import VoltageControlEnv
+
+        return VoltageControlEnv()
+    elif env_name == "smax":
+        from og_marl.environments.jaxmarl_smax import SMAX
+
+        return SMAX(scenario)
+    elif env_name == "lbf":
+        from og_marl.environments.jumanji_lbf import JumanjiLBF
+
+        return JumanjiLBF(scenario)
+    elif env_name == "rware":
+        from og_marl.environments.jumanji_rware import JumanjiRware
+
+        return JumanjiRware(scenario)
+    else:
+        raise ValueError(f"Environment '{env_name}' not recognised.")
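
For orientation, this is how the new dispatcher is meant to be called. A minimal sketch, assuming og_marl and the relevant environment dependencies are installed; the environment and scenario names come from the branches above:

    from og_marl.environments import get_environment

    # Scenario names are environment-specific, e.g. "3m" for SMAX,
    # "8x8-2p-2f-coop" for LBF, or "tiny-4ag" for RWARE.
    env = get_environment("smax", "3m")
    observations, info = env.reset()

Keeping the imports inside each branch means a missing optional dependency (e.g. flatland) only fails when that environment is actually requested.
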
diff --git a/og_marl/environments/jaxmarl_smax.py b/og_marl/environments/jaxmarl_smax.py
new file mode 100644
index 00000000..29f05fd6
--- /dev/null
+++ b/og_marl/environments/jaxmarl_smax.py
@@ -0,0 +1,98 @@
+# Copyright 2023 InstaDeep Ltd. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Wrapper for the JaxMARL SMAX environment."""
+from typing import Any, Dict
+
+import jax
+import numpy as np
+from jaxmarl import make
+from jaxmarl.environments.smax import map_name_to_scenario
+from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn
+
+
+class SMAX(BaseEnvironment):
+
+    """Environment wrapper for JaxMARL SMAX."""
+
+    def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None:
+        """Constructor."""
+        scenario = map_name_to_scenario(scenario_name)
+
+        self._environment = make(
+            "HeuristicEnemySMAX",
+            enemy_shoots=True,
+            scenario=scenario,
+            use_self_play_reward=False,
+            walls_cause_death=True,
+            see_enemy_actions=False,
+        )
+
+        self._num_agents = self._environment.num_agents
+        self.possible_agents = self._environment.agents
+        self._num_actions = int(self._environment.action_spaces[self.possible_agents[0]].n)
+
+        self._state = ...  # JaxMARL environment state
+
+        self.info_spec: Dict[str, Any] = {}  # TODO add global state spec
+
+        self._key = jax.random.PRNGKey(seed)
+
+        self._env_step = jax.jit(self._environment.step)
+
+    def reset(self) -> ResetReturn:
+        """Resets the env."""
+        # Reset the environment
+        self._key, sub_key = jax.random.split(self._key)
+        obs, self._state = self._environment.reset(sub_key)
+
+        observations = {
+            agent: np.asarray(obs[agent], dtype=np.float32) for agent in self.possible_agents
+        }
+        legals = {
+            agent: np.array(legal, "int64")
+            for agent, legal in self._environment.get_avail_actions(self._state).items()
+        }
+        state = np.asarray(obs["world_state"], "float32")
+
+        # Infos
+        info = {"legals": legals, "state": state}
+
+        return observations, info
+
+    def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
+        """Steps in env."""
+        self._key, sub_key = jax.random.split(self._key)
+
+        # Step the environment (use the jitted step function)
+        obs, self._state, reward, done, infos = self._env_step(
+            sub_key, self._state, actions
+        )
+
+        observations = {
+            agent: np.asarray(obs[agent], dtype=np.float32) for agent in self.possible_agents
+        }
+        legals = {
+            agent: np.array(legal, "int64")
+            for agent, legal in self._environment.get_avail_actions(self._state).items()
+        }
+        state = np.asarray(obs["world_state"], "float32")
+
+        # Infos
+        info = {"legals": legals, "state": state}
+
+        rewards = {agent: reward[agent] for agent in self.possible_agents}
+        terminals = {agent: done["__all__"] for agent in self.possible_agents}
+        truncations = {agent: False for agent in self.possible_agents}
+
+        return observations, rewards, terminals, truncations, info
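
The wrapper exposes the usual reset/step API, with legal-action masks and the global state passed through `info`. A minimal random-rollout sketch against that API, assuming JaxMARL is installed (the NumPy action sampling is illustrative only):

    import numpy as np

    from og_marl.environments.jaxmarl_smax import SMAX

    env = SMAX("3m")
    observations, info = env.reset()

    done = False
    while not done:
        # Pick a random legal action for each agent from the "legals" mask.
        actions = {
            agent: np.random.choice(np.nonzero(info["legals"][agent])[0])
            for agent in env.possible_agents
        }
        observations, rewards, terminals, truncations, info = env.step(actions)
        done = all(terminals.values())
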
diff --git a/og_marl/environments/jumanji_lbf.py b/og_marl/environments/jumanji_lbf.py
new file mode 100644
index 00000000..f74da913
--- /dev/null
+++ b/og_marl/environments/jumanji_lbf.py
@@ -0,0 +1,111 @@
+# Copyright 2023 InstaDeep Ltd. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Wrapper for the Jumanji Level-Based Foraging (LBF) environment."""
+from typing import Any, Dict
+
+import jax
+import jax.numpy as jnp
+import jumanji
+import numpy as np
+from jumanji.environments.routing.lbf.generator import RandomGenerator
+
+from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn
+
+task_configs = {
+    "8x8-2p-2f-coop": {
+        "grid_size": 8,  # size of the grid to generate.
+        "fov": 8,  # field of view of an agent.
+        "num_agents": 2,  # number of agents on the grid.
+        "num_food": 2,  # number of food in the environment.
+        "max_agent_level": 2,  # maximum level of the agents (inclusive).
+        "force_coop": True,  # force cooperation between agents.
+    },
+    "15x15-4p-5f": {
+        "grid_size": 15,  # size of the grid to generate.
+        "fov": 15,  # field of view of an agent.
+        "num_agents": 4,  # number of agents on the grid.
+        "num_food": 5,  # number of food in the environment.
+        "max_agent_level": 2,  # maximum level of the agents (inclusive).
+        "force_coop": False,  # force cooperation between agents.
+    },
+}
+
+
+class JumanjiLBF(BaseEnvironment):
+
+    """Environment wrapper for Jumanji Level-Based Foraging."""
+
+    def __init__(self, scenario_name: str = "8x8-2p-2f-coop", seed: int = 0) -> None:
+        """Constructor."""
+        self._environment = jumanji.make(
+            "LevelBasedForaging-v0",
+            time_limit=100,
+            generator=RandomGenerator(**task_configs[scenario_name]),
+        )
+        self._num_agents = self._environment.num_agents
+        self._num_actions = int(self._environment.action_spec().num_values[0])
+        self.possible_agents = [f"agent_{i}" for i in range(self._num_agents)]
+        self._state = ...  # Jumanji environment state
+
+        self.info_spec: Dict[str, Any] = {}  # TODO add global state spec
+
+        self._key = jax.random.PRNGKey(seed)
+
+        self._env_step = jax.jit(self._environment.step)
+
+    def reset(self) -> ResetReturn:
+        """Resets the env."""
+        # Reset the environment
+        self._key, sub_key = jax.random.split(self._key)
+        self._state, timestep = self._environment.reset(sub_key)
+
+        observations = {
+            agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+        legals = {
+            agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+
+        # Infos
+        info = {"legals": legals}
+
+        return observations, info
+
+    def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
+        """Steps in env."""
+        actions = jnp.array([actions[agent] for agent in self.possible_agents])
+        # Step the environment
+        self._state, timestep = self._env_step(self._state, actions)
+
+        observations = {
+            agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+        legals = {
+            agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+        rewards = {agent: np.asarray(timestep.reward) for agent in self.possible_agents}
+        terminals = {agent: np.asarray(timestep.last()) for agent in self.possible_agents}
+        truncations = {agent: np.asarray(False) for agent in self.possible_agents}
+
+        # # Global state # TODO
+        # env_state = self._create_state_representation(observations)
+
+        # Extra infos
+        info = {"legals": legals}
+
+        return observations, rewards, terminals, truncations, info
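
Note that `scenario_name` must be a key of `task_configs`, since the config dict is unpacked straight into Jumanji's `RandomGenerator`. Adding a scenario is therefore just adding an entry with the same six fields. A sketch with a hypothetical "10x10-3p-3f" task (the values are illustrative, not a benchmarked setting):

    from og_marl.environments.jumanji_lbf import JumanjiLBF, task_configs

    task_configs["10x10-3p-3f"] = {
        "grid_size": 10,
        "fov": 10,
        "num_agents": 3,
        "num_food": 3,
        "max_agent_level": 2,
        "force_coop": False,
    }

    env = JumanjiLBF("10x10-3p-3f")
    observations, info = env.reset()
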
+"""Base wrapper for Jumanji environments.""" +from typing import Any, Dict + +import jax +import jax.numpy as jnp +import jumanji +import numpy as np +from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator + +from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn + +task_configs = { + "tiny-4ag": { + "column_height": 8, + "shelf_rows": 1, + "shelf_columns": 3, + "num_agents": 4, + "sensor_range": 1, + "request_queue_size": 4, + }, + "tiny-2ag": { + "column_height": 8, + "shelf_rows": 1, + "shelf_columns": 3, + "num_agents": 2, + "sensor_range": 1, + "request_queue_size": 2, + }, +} + + +class JumanjiRware(BaseEnvironment): + + """Environment wrapper for Jumanji environments.""" + + def __init__(self, scenario_name: str = "tiny-4ag", seed: int = 0) -> None: + """Constructor.""" + self._environment = jumanji.make( + "RobotWarehouse-v0", + generator=RandomGenerator(**task_configs[scenario_name]), + ) + self._num_agents = self._environment.num_agents + self._num_actions = int(self._environment.action_spec().num_values[0]) + self.possible_agents = [f"agent_{i}" for i in range(self._num_agents)] + self._state = ... # Jumanji environment state + + self.info_spec: Dict[str, Any] = {} # TODO add global state spec + + self._key = jax.random.PRNGKey(seed) + + self._env_step = jax.jit(self._environment.step, donate_argnums=0) + + def reset(self) -> ResetReturn: + """Resets the env.""" + # Reset the environment + self._key, sub_key = jax.random.split(self._key) + self._state, timestep = self._environment.reset(sub_key) + + observations = { + agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32) + for i, agent in enumerate(self.possible_agents) + } + legals = { + agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32) + for i, agent in enumerate(self.possible_agents) + } + + # Infos + info = {"legals": legals} + + return observations, info + + def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps in env.""" + actions = jnp.array(list(actions.values())) # .squeeze(-1) + # Step the environment + self._state, timestep = self._env_step(self._state, actions) + + observations = { + agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32) + for i, agent in enumerate(self.possible_agents) + } + legals = { + agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32) + for i, agent in enumerate(self.possible_agents) + } + rewards = {agent: np.asarray(timestep.reward) for agent in self.possible_agents} + terminals = {agent: np.asarray(timestep.last()) for agent in self.possible_agents} + truncations = {agent: np.asarray(False) for agent in self.possible_agents} + + # # Global state # TODO + # env_state = self._create_state_representation(observations) + + # Extra infos + info = {"legals": legals} + + return observations, rewards, terminals, truncations, info diff --git a/og_marl/environments/utils.py b/og_marl/environments/utils.py deleted file mode 100644 index 9d641181..00000000 --- a/og_marl/environments/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -# type: ignore - -# Copyright 2023 InstaDeep Ltd. All rights reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
diff --git a/og_marl/environments/utils.py b/og_marl/environments/utils.py
deleted file mode 100644
index 9d641181..00000000
--- a/og_marl/environments/utils.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# type: ignore
-
-# Copyright 2023 InstaDeep Ltd. All rights reserved.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from og_marl.environments.base import BaseEnvironment
-
-
-def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
-    if env_name == "smac_v1":
-        from og_marl.environments.smacv1 import SMACv1
-
-        return SMACv1(scenario)
-    elif env_name == "smac_v2":
-        from og_marl.environments.smacv2 import SMACv2
-
-        return SMACv2(scenario)
-    elif env_name == "mamujoco":
-        from og_marl.environments.old_mamujoco import MAMuJoCo
-
-        return MAMuJoCo(scenario)
-    elif env_name == "gymnasium_mamujoco":
-        from og_marl.environments.gymnasium_mamujoco import MAMuJoCo
-
-        return MAMuJoCo(scenario)
-    elif env_name == "flatland":
-        from og_marl.environments.flatland_wrapper import Flatland
-
-        return Flatland(scenario)
-    elif env_name == "voltage_control":
-        from og_marl.environments.voltage_control import VoltageControlEnv
-
-        return VoltageControlEnv()
-    else:
-        raise ValueError("Environment not recognised.")
diff --git a/og_marl/offline_dataset.py b/og_marl/offline_dataset.py
index 271f533b..300734ee 100644
--- a/og_marl/offline_dataset.py
+++ b/og_marl/offline_dataset.py
@@ -299,12 +299,12 @@ def download_and_unzip_vault(
     scenario_name: str,
     dataset_base_dir: str = "./vaults",
 ) -> None:
-    dataset_download_url = VAULT_INFO[env_name][scenario_name]["url"]
-
     if check_directory_exists_and_not_empty(f"{dataset_base_dir}/{env_name}/{scenario_name}.vlt"):
         print(f"Vault '{dataset_base_dir}/{env_name}/{scenario_name}' already exists.")
         return
 
+    dataset_download_url = VAULT_INFO[env_name][scenario_name]["url"]
+
     os.makedirs(f"{dataset_base_dir}/tmp/", exist_ok=True)
     os.makedirs(f"{dataset_base_dir}/{env_name}/", exist_ok=True)
 
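
With the URL lookup moved below the existence check, `download_and_unzip_vault` returns before touching `VAULT_INFO` when the vault is already on disk, so repeated calls are cheap no-ops. A sketch, assuming "smac_v1"/"3m" is a valid entry in `VAULT_INFO`:

    from og_marl.offline_dataset import download_and_unzip_vault

    # First call downloads and unzips; the second returns early because
    # ./vaults/smac_v1/3m.vlt already exists and is non-empty.
    download_and_unzip_vault("smac_v1", "3m", dataset_base_dir="./vaults")
    download_and_unzip_vault("smac_v1", "3m", dataset_base_dir="./vaults")
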
diff --git a/og_marl/tf2/utils.py b/og_marl/tf2/utils.py
index ff72e590..71831956 100644
--- a/og_marl/tf2/utils.py
+++ b/og_marl/tf2/utils.py
@@ -108,7 +108,7 @@ def unroll_rnn(rnn_network: Module, inputs: Tensor, resets: Tensor) -> Tensor:
             rnn_network.initial_state(B)[0],  # type: ignore
             hidden_state[0],
         ),
-    )  # hidden state wrapped im tuple
+    )  # hidden state wrapped in tuple
 
     return tf.stack(outputs, axis=0)  # type: ignore
 
@@ -148,25 +148,34 @@ def batched_agents(agents, batch_dict):  # type: ignore
         "rewards": [],
         "terminals": [],
         "truncations": [],
+        "infos": {},
     }
 
     for agent in agents:
         for key in batched_agents_dict:
+            if key == "infos":
+                continue
             batched_agents_dict[key].append(batch_dict[key][agent])
 
     for key, value in batched_agents_dict.items():
+        if key == "infos":
+            continue
         batched_agents_dict[key] = tf.stack(value, axis=2)
 
     batched_agents_dict["terminals"] = tf.cast(batched_agents_dict["terminals"], "float32")
     batched_agents_dict["truncations"] = tf.cast(batched_agents_dict["truncations"], "float32")
 
     if "legals" in batch_dict["infos"]:
-        batched_agents_dict["legals"] = []
+        batched_agents_dict["infos"]["legals"] = []
         for agent in agents:
-            batched_agents_dict["legals"].append(batch_dict["infos"]["legals"][agent])
-        batched_agents_dict["legals"] = tf.stack(batched_agents_dict["legals"], axis=2)
+            batched_agents_dict["infos"]["legals"].append(batch_dict["infos"]["legals"][agent])
+        batched_agents_dict["infos"]["legals"] = tf.stack(
+            batched_agents_dict["infos"]["legals"], axis=2
+        )
 
     if "state" in batch_dict["infos"]:
-        batched_agents_dict["state"] = tf.convert_to_tensor(batch_dict["infos"]["state"], "float32")
+        batched_agents_dict["infos"]["state"] = tf.convert_to_tensor(
+            batch_dict["infos"]["state"], "float32"
+        )
 
     if "mask" in batch_dict["infos"]:
         batched_agents_dict["mask"] = tf.convert_to_tensor(batch_dict["infos"]["mask"], "float32")
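
The net effect of the `batched_agents` changes is that per-agent `[B, T, ...]` tensors are stacked on a new agent axis at position 2, and `legals`/`state` now live under `infos` instead of at the top level. An illustrative shape check with dummy tensors (assuming the function returns the batched dict, as its callers expect):

    import tensorflow as tf

    from og_marl.tf2.utils import batched_agents

    agents = ["agent_0", "agent_1"]
    B, T, obs_dim, n_actions = 2, 5, 3, 7  # illustrative sizes
    batch_dict = {
        "observations": {a: tf.zeros((B, T, obs_dim)) for a in agents},
        "actions": {a: tf.zeros((B, T), tf.int32) for a in agents},
        "rewards": {a: tf.zeros((B, T)) for a in agents},
        "terminals": {a: tf.zeros((B, T), tf.bool) for a in agents},
        "truncations": {a: tf.zeros((B, T), tf.bool) for a in agents},
        "infos": {"legals": {a: tf.ones((B, T, n_actions)) for a in agents}},
    }

    batched = batched_agents(agents, batch_dict)
    print(batched["observations"].shape)     # (2, 5, 2, 3): agents on axis 2
    print(batched["infos"]["legals"].shape)  # (2, 5, 2, 7)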