diff --git a/baselines/main.py b/baselines/main.py
index 1fdb8be0..1ebae9d9 100644
--- a/baselines/main.py
+++ b/baselines/main.py
@@ -14,7 +14,7 @@
 from absl import app, flags
 
 from og_marl.environments.utils import get_environment
-from og_marl.loggers import JsonWriter, WandbLogger
+from og_marl.loggers import JsonWriter, TerminalLogger, WandbLogger
 from og_marl.offline_dataset import download_and_unzip_vault
 from og_marl.replay_buffers import FlashbaxReplayBuffer
 from og_marl.tf2.systems import get_system
@@ -23,9 +23,9 @@
 set_growing_gpu_memory()
 
 FLAGS = flags.FLAGS
-flags.DEFINE_string("env", "smac_v1", "Environment name.")
-flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
-flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
+flags.DEFINE_string("env", "rware", "Environment name.")
+flags.DEFINE_string("scenario", "tiny-4ag", "Environment scenario name.")
+flags.DEFINE_string("dataset", "Replay", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
 flags.DEFINE_string("system", "dbc", "System name.")
 flags.DEFINE_integer("seed", 42, "Seed.")
 flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
@@ -52,7 +52,7 @@ def main(_):
         print("Vault not found. Exiting.")
         return
 
-    logger = WandbLogger(project="og-marl-baselines", config=config)
+    logger = TerminalLogger()  # WandbLogger(project="og-marl-baselines", config=config)
 
     json_writer = JsonWriter(
         "logs",
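A usage note, not part of the patch: with the new defaults the baseline entry point is launched as

    python baselines/main.py --env=rware --scenario=tiny-4ag --dataset=Replay --system=dbc

and any of the absl flags defined above can be overridden on the command line in the same way. The vault check in main() must still pass, otherwise the script prints "Vault not found. Exiting." and returns.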
"""Base wrapper for Jumanji environments.""" +import time from typing import Any, Dict -import numpy as np import jax +import jax.numpy as jnp +import jumanji +import numpy as np +from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn +task_configs = { + "tiny-4ag": { + "column_height": 8, + "shelf_rows": 1, + "shelf_columns": 3, + "num_agents": 4, + "sensor_range": 1, + "request_queue_size": 4, + } +} -class JumanjiBase(BaseEnvironment): +class JumanjiRware(BaseEnvironment): """Environment wrapper for Jumanji environments.""" - def __init__(self) -> None: + def __init__(self, scenario_name = "tiny-4ag", seed = 0) -> None: """Constructor.""" - self._environment = ... + self._environment = jumanji.make( + "RobotWarehouse-v0", + generator=RandomGenerator(**task_configs[scenario_name]), + ) self._num_agents = self._environment.num_agents + self._num_actions = int(self._environment.action_spec().num_values[0]) self.possible_agents = [f"agent_{i}" for i in range(self._num_agents)] self._state = ... # Jumanji environment state - self.info_spec: Dict[str, Any] = {} + self.info_spec: Dict[str, Any] = {} # TODO add global state spec + + self._key = jax.random.PRNGKey(seed) + + self._env_step = jax.jit(self._environment.step, donate_argnums=0) def reset(self) -> ResetReturn: """Resets the env.""" # Reset the environment - self.key, sub_key = jax.random.split(self.key) + self._key, sub_key = jax.random.split(self._key) self._state, timestep = self._environment.reset(sub_key) - # Infos - info = {"state": env_state} + observations = {agent: np.asarray( + timestep.observation.agents_view[i], dtype=np.float32) + for i, agent in enumerate(self.possible_agents) + } + legals = {agent: np.asarray( + timestep.observation.action_mask[i], dtype=np.int32) + for i, agent in enumerate(self.possible_agents) + } - # Convert observations to OLT format - observations = self._convert_observations(observations, False) + # Infos + info = {"legals": legals} return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: """Steps in env.""" - actions = ... 
diff --git a/og_marl/environments/utils.py b/og_marl/environments/utils.py
index 9d641181..c12e262d 100644
--- a/og_marl/environments/utils.py
+++ b/og_marl/environments/utils.py
@@ -43,5 +43,9 @@ def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
         from og_marl.environments.voltage_control import VoltageControlEnv
 
         return VoltageControlEnv()
+    elif env_name == "rware":
+        from og_marl.environments.jumanji_rware import JumanjiRware
+
+        return JumanjiRware(scenario)
     else:
         raise ValueError("Environment not recognised.")
diff --git a/og_marl/offline_dataset.py b/og_marl/offline_dataset.py
index 271f533b..300734ee 100644
--- a/og_marl/offline_dataset.py
+++ b/og_marl/offline_dataset.py
@@ -299,12 +299,12 @@ def download_and_unzip_vault(
     scenario_name: str,
     dataset_base_dir: str = "./vaults",
 ) -> None:
-    dataset_download_url = VAULT_INFO[env_name][scenario_name]["url"]
-
     if check_directory_exists_and_not_empty(f"{dataset_base_dir}/{env_name}/{scenario_name}.vlt"):
         print(f"Vault '{dataset_base_dir}/{env_name}/{scenario_name}' already exists.")
         return
 
+    dataset_download_url = VAULT_INFO[env_name][scenario_name]["url"]
+
     os.makedirs(f"{dataset_base_dir}/tmp/", exist_ok=True)
     os.makedirs(f"{dataset_base_dir}/{env_name}/", exist_ok=True)
diff --git a/og_marl/tf2/systems/bc.py b/og_marl/tf2/systems/bc.py
index 5f8f047b..c7a71422 100644
--- a/og_marl/tf2/systems/bc.py
+++ b/og_marl/tf2/systems/bc.py
@@ -108,10 +108,11 @@ def _tf_select_actions(
             probs = tf.nn.softmax(logits)
 
             if legal_actions is not None:
-                agent_legals = tf.expand_dims(legal_actions[agent], axis=0)
+                agent_legals = tf.cast(tf.expand_dims(legal_actions[agent], axis=0), "float32")
                 probs = (probs * agent_legals) / tf.reduce_sum(
                     probs * agent_legals
                 )  # mask and renorm
+            probs = tf.cast(probs, "float32")
 
             action = tfp.distributions.Categorical(probs=probs).sample(1)
@@ -124,7 +125,7 @@ def train_step(self, experience: Experience) -> Dict[str, Numeric]:
         logs = self._tf_train_step(experience)
         return logs  # type: ignore
 
-    @tf.function(jit_compile=True)
+    @tf.function()  # jit_compile=True
     def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
         # Unpack the relevant quantities
         observations = experience["observations"]
diff --git a/og_marl/tf2/systems/idrqn.py b/og_marl/tf2/systems/idrqn.py
index 4826421d..6f927859 100644
--- a/og_marl/tf2/systems/idrqn.py
+++ b/og_marl/tf2/systems/idrqn.py
@@ -118,7 +118,7 @@ def select_actions(
             lambda x: x.numpy(), actions
         )  # convert to numpy and squeeze batch dim
 
-    @tf.function(jit_compile=True)
+    # @tf.function(jit_compile=True)
     def _tf_select_actions(
         self,
         env_step_ctr: int,
@@ -149,7 +149,7 @@ def _tf_select_actions(
             epsilon = tf.maximum(1.0 - self._eps_dec * env_step_ctr, self._eps_min)
 
             greedy_probs = tf.one_hot(greedy_action, masked_q_values.shape[-1])
-            explore_probs = agent_legal_actions / tf.reduce_sum(agent_legal_actions)
+            explore_probs = tf.cast(agent_legal_actions / tf.reduce_sum(agent_legal_actions), dtype=tf.float32)
 
             probs = (1.0 - epsilon) * greedy_probs + epsilon * explore_probs
             probs = tf.expand_dims(probs, axis=0)
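The casts added in bc.py and idrqn.py above guard against the same dtype pitfall: TensorFlow does not promote dtypes automatically, and dividing an integer legal-action mask by its sum produces a float64 tensor that cannot be combined with float32 probabilities. A minimal illustration, not taken from the repository:

# Standalone illustration of the dtype issue (not repository code).
import tensorflow as tf

legals = tf.constant([1, 0, 1, 1], dtype=tf.int32)  # integer legal-action mask
greedy_probs = tf.one_hot(2, 4)  # float32

# Integer division via "/" yields float64, so cast before mixing with float32.
explore_probs = tf.cast(legals / tf.reduce_sum(legals), dtype=tf.float32)

epsilon = 0.1
probs = (1.0 - epsilon) * greedy_probs + epsilon * explore_probs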
diff --git a/vault_conversion.ipynb b/vault_conversion.ipynb
new file mode 100644
index 00000000..3db5e1c0
--- /dev/null
+++ b/vault_conversion.ipynb
@@ -0,0 +1,148 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading vault found at /Users/callum/og-marl/vaults/rware/tiny-4ag.vlt/Replay\n"
+     ]
+    }
+   ],
+   "source": [
+    "from flashbax.vault import Vault\n",
+    "vlt = Vault(\n",
+    "    vault_name=\"rware/tiny-4ag.vlt\",\n",
+    "    vault_uid=\"Replay\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = vlt.read()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "dict_keys(['action', 'done', 'legal_action_mask', 'observation', 'reward'])"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data.experience.keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax.numpy as jnp\n",
+    "\n",
+    "def convert_timestep(input_dict):\n",
+    "    output_dict = {\n",
+    "        \"observations\": jnp.asarray(input_dict[\"observation\"][..., 4:], dtype=jnp.float32),\n",
+    "        \"rewards\": input_dict[\"reward\"],\n",
+    "        \"actions\": jnp.asarray(input_dict[\"action\"], dtype=jnp.int32),\n",
+    "        \"terminals\": input_dict[\"done\"],\n",
+    "        \"truncations\": jnp.zeros_like(input_dict[\"done\"], dtype=jnp.bool_),\n",
+    "        \"infos\": {\n",
+    "            \"legals\": jnp.asarray(input_dict[\"legal_action_mask\"], dtype=jnp.int32),\n",
+    "            \"env_state\": input_dict[\"observation\"][..., 4:].reshape(*input_dict[\"observation\"].shape[:2], -1)\n",
+    "        }\n",
+    "    }\n",
+    "    return output_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "New vault created at /Users/callum/og-marl/vaults/rware/tiny-4ag-converted.vlt/Replay\n",
+      "Since the provided buffer state has a temporal dimension of 6256, you must write to the vault at least every 6255 timesteps to avoid data loss.\n"
+     ]
+    }
+   ],
+   "source": [
+    "new_vault = Vault(\n",
+    "    vault_name=\"rware/tiny-4ag-converted.vlt\",\n",
+    "    vault_uid=\"Replay\",\n",
+    "    experience_structure=convert_timestep(data.experience),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "6256"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from flashbax.buffers.trajectory_buffer import TrajectoryBufferState\n",
+    "import jax\n",
+    "new_vault.write(\n",
+    "    TrajectoryBufferState(\n",
+    "        experience=convert_timestep(data.experience),\n",
+    "        current_index=jax.tree_util.tree_flatten(data.experience)[0][0].shape[1],\n",
+    "        is_full=False,\n",
+    "    )\n",
+    ")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "ogmarl",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
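A possible follow-up check after running the notebook, sketched here rather than taken from it: re-open the converted vault and confirm the renamed fields and their shapes. It assumes the same working directory as the notebook so the relative vault path resolves.

# Hypothetical sanity check on the converted vault (not part of the notebook).
import jax
from flashbax.vault import Vault

converted = Vault(vault_name="rware/tiny-4ag-converted.vlt", vault_uid="Replay")
experience = converted.read().experience

# Expect the og-marl field names written by convert_timestep above.
print(sorted(experience.keys()))  # ['actions', 'infos', 'observations', 'rewards', 'terminals', 'truncations']
print(jax.tree_util.tree_map(lambda x: x.shape, experience))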