wip: training with rware
callumtilbury committed Mar 4, 2024
1 parent 5f13430 commit fb3d307
Showing 8 changed files with 248 additions and 27 deletions.
10 changes: 5 additions & 5 deletions baselines/main.py
@@ -14,7 +14,7 @@
from absl import app, flags

from og_marl.environments.utils import get_environment
-from og_marl.loggers import JsonWriter, WandbLogger
+from og_marl.loggers import JsonWriter, TerminalLogger, WandbLogger
from og_marl.offline_dataset import download_and_unzip_vault
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.systems import get_system
@@ -23,9 +23,9 @@
set_growing_gpu_memory()

FLAGS = flags.FLAGS
flags.DEFINE_string("env", "smac_v1", "Environment name.")
flags.DEFINE_string("scenario", "3m", "Environment scenario name.")
flags.DEFINE_string("dataset", "Good", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
flags.DEFINE_string("env", "rware", "Environment name.")
flags.DEFINE_string("scenario", "tiny-4ag", "Environment scenario name.")
flags.DEFINE_string("dataset", "Replay", "Dataset type.: 'Good', 'Medium', 'Poor' or 'Replay' ")
flags.DEFINE_string("system", "dbc", "System name.")
flags.DEFINE_integer("seed", 42, "Seed.")
flags.DEFINE_float("trainer_steps", 5e4, "Number of training steps.")
@@ -52,7 +52,7 @@ def main(_):
print("Vault not found. Exiting.")
return

logger = WandbLogger(project="og-marl-baselines", config=config)
logger = TerminalLogger() #WandbLogger(project="og-marl-baselines", config=config)

json_writer = JsonWriter(
"logs",
28 changes: 28 additions & 0 deletions examples/tf2/online/idrqn_rware.py
@@ -0,0 +1,28 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from og_marl.environments.jumanji_rware import JumanjiRware
from og_marl.loggers import TerminalLogger
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.systems.idrqn import IDRQNSystem

env = JumanjiRware()  # TODO: add scenario name?

logger = TerminalLogger()  # WandbLogger(entity="claude_formanek")

system = IDRQNSystem(env, logger, eps_decay_timesteps=10_000)

replay_buffer = FlashbaxReplayBuffer(sequence_length=20)

system.train_online(replay_buffer)
72 changes: 56 additions & 16 deletions og_marl/environments/jumanji_rware.py
@@ -12,51 +12,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base wrapper for Jumanji environments."""
import time
from typing import Any, Dict

import numpy as np
import jax
import jax.numpy as jnp
import jumanji
import numpy as np
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator

from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn

+task_configs = {
+    "tiny-4ag": {
+        "column_height": 8,
+        "shelf_rows": 1,
+        "shelf_columns": 3,
+        "num_agents": 4,
+        "sensor_range": 1,
+        "request_queue_size": 4,
+    }
+}

-class JumanjiBase(BaseEnvironment):
+class JumanjiRware(BaseEnvironment):

    """Environment wrapper for Jumanji environments."""

-    def __init__(self) -> None:
+    def __init__(self, scenario_name: str = "tiny-4ag", seed: int = 0) -> None:
        """Constructor."""
-        self._environment = ...
+        self._environment = jumanji.make(
+            "RobotWarehouse-v0",
+            generator=RandomGenerator(**task_configs[scenario_name]),
+        )
        self._num_agents = self._environment.num_agents
        self._num_actions = int(self._environment.action_spec().num_values[0])
        self.possible_agents = [f"agent_{i}" for i in range(self._num_agents)]
        self._state = ...  # Jumanji environment state

-        self.info_spec: Dict[str, Any] = {}
+        self.info_spec: Dict[str, Any] = {}  # TODO: add global state spec
+
+        self._key = jax.random.PRNGKey(seed)
+
+        self._env_step = jax.jit(self._environment.step, donate_argnums=0)

    def reset(self) -> ResetReturn:
        """Resets the env."""
        # Reset the environment
-        self.key, sub_key = jax.random.split(self.key)
+        self._key, sub_key = jax.random.split(self._key)
        self._state, timestep = self._environment.reset(sub_key)

-        # Infos
-        info = {"state": env_state}
+        observations = {
+            agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+        legals = {
+            agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32)
+            for i, agent in enumerate(self.possible_agents)
+        }

-        # Convert observations to OLT format
-        observations = self._convert_observations(observations, False)
+        # Infos
+        info = {"legals": legals}

        return observations, info

    def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
        """Steps in env."""
-        actions = ...  # convert actions
+        actions = jnp.array(list(actions.values())).squeeze(-1)
        # Step the environment
-        self._state, timestep = self._env.step(self._state, actions)
-
-        # Global state
-        env_state = self._create_state_representation(observations)
+        self._state, timestep = self._env_step(self._state, actions)

+        observations = {
+            agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+        legals = {
+            agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32)
+            for i, agent in enumerate(self.possible_agents)
+        }
+        rewards = {agent: timestep.reward for agent in self.possible_agents}
+        terminals = {agent: timestep.last() for agent in self.possible_agents}
+        truncations = {agent: False for agent in self.possible_agents}
+
+        # # Global state
+        # env_state = self._create_state_representation(observations)

        # Extra infos
-        info = {"state": env_state}
+        info = {"legals": legals}

        return observations, rewards, terminals, truncations, info
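
As a sanity check, the rewritten wrapper can be stepped end to end. A hypothetical smoke test (assumes jumanji and og_marl are importable; the action choice just picks each agent's first legal action):

import numpy as np

from og_marl.environments.jumanji_rware import JumanjiRware

env = JumanjiRware(scenario_name="tiny-4ag", seed=0)
observations, info = env.reset()

actions = {
    agent: np.array([int(np.argmax(info["legals"][agent]))])  # first legal action
    for agent in env.possible_agents
}
observations, rewards, terminals, truncations, info = env.step(actions)
print(rewards)  # the same team reward repeated per agent, per step() above

Note the shape convention: each action is a length-1 array so that step() can stack the per-agent actions and squeeze the trailing dimension.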
4 changes: 4 additions & 0 deletions og_marl/environments/utils.py
@@ -43,5 +43,9 @@ def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
        from og_marl.environments.voltage_control import VoltageControlEnv

        return VoltageControlEnv()
+    elif env_name == "rware":
+        from og_marl.environments.jumanji_rware import JumanjiRware
+
+        return JumanjiRware(scenario)
    else:
        raise ValueError("Environment not recognised.")
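
With the new branch in place, the baseline entry point can build the wrapper by name. A quick check, using the scenario string from elsewhere in this commit:

from og_marl.environments.utils import get_environment

env = get_environment("rware", "tiny-4ag")
print(env.possible_agents)  # ["agent_0", "agent_1", "agent_2", "agent_3"]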
4 changes: 2 additions & 2 deletions og_marl/offline_dataset.py
@@ -299,12 +299,12 @@ def download_and_unzip_vault(
    scenario_name: str,
    dataset_base_dir: str = "./vaults",
) -> None:
-    dataset_download_url = VAULT_INFO[env_name][scenario_name]["url"]

    if check_directory_exists_and_not_empty(f"{dataset_base_dir}/{env_name}/{scenario_name}.vlt"):
        print(f"Vault '{dataset_base_dir}/{env_name}/{scenario_name}' already exists.")
        return

+    dataset_download_url = VAULT_INFO[env_name][scenario_name]["url"]
+
    os.makedirs(f"{dataset_base_dir}/tmp/", exist_ok=True)
    os.makedirs(f"{dataset_base_dir}/{env_name}/", exist_ok=True)
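
The reordering matters here: looking up VAULT_INFO only after the existence check means a vault that already sits on disk (such as the converted rware vault produced by the notebook below) no longer needs a download URL entry before the early return fires, which is presumably the motivation for the change.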
5 changes: 3 additions & 2 deletions og_marl/tf2/systems/bc.py
@@ -108,10 +108,11 @@ def _tf_select_actions(
        probs = tf.nn.softmax(logits)

        if legal_actions is not None:
-            agent_legals = tf.expand_dims(legal_actions[agent], axis=0)
+            agent_legals = tf.cast(tf.expand_dims(legal_actions[agent], axis=0), "float32")
            probs = (probs * agent_legals) / tf.reduce_sum(
                probs * agent_legals
            )  # mask and renorm
+        probs = tf.cast(probs, "float32")

        action = tfp.distributions.Categorical(probs=probs).sample(1)

@@ -124,7 +125,7 @@ def train_step(self, experience: Experience) -> Dict[str, Numeric]:
        logs = self._tf_train_step(experience)
        return logs  # type: ignore

-    @tf.function(jit_compile=True)
+    @tf.function()  # jit_compile=True
    def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
        # Unpack the relevant quantities
        observations = experience["observations"]
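
The cast added above guards against TensorFlow's strict dtype matching: multiplying float32 action probabilities by an integer legal-action mask raises an error rather than promoting. A minimal sketch of the failure and the fix, assuming an int32 mask as the rware vault stores it:

import tensorflow as tf

probs = tf.nn.softmax(tf.constant([0.5, 1.0, 0.2, 0.1]))  # float32 action probs
legals = tf.constant([1, 1, 0, 1], dtype=tf.int32)        # integer legal-action mask

# probs * legals   # InvalidArgumentError: cannot compute Mul of float32 and int32
legals32 = tf.cast(legals, tf.float32)
masked = (probs * legals32) / tf.reduce_sum(probs * legals32)  # mask and renorm, as in the diff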
4 changes: 2 additions & 2 deletions og_marl/tf2/systems/idrqn.py
@@ -118,7 +118,7 @@ def select_actions(
            lambda x: x.numpy(), actions
        )  # convert to numpy and squeeze batch dim

-    @tf.function(jit_compile=True)
+    # @tf.function(jit_compile=True)
    def _tf_select_actions(
        self,
        env_step_ctr: int,
@@ -149,7 +149,7 @@ def _tf_select_actions(
        epsilon = tf.maximum(1.0 - self._eps_dec * env_step_ctr, self._eps_min)

        greedy_probs = tf.one_hot(greedy_action, masked_q_values.shape[-1])
-        explore_probs = agent_legal_actions / tf.reduce_sum(agent_legal_actions)
+        explore_probs = tf.cast(agent_legal_actions / tf.reduce_sum(agent_legal_actions), dtype=tf.float32)
        probs = (1.0 - epsilon) * greedy_probs + epsilon * explore_probs
        probs = tf.expand_dims(probs, axis=0)
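
The cast added here fixes a subtler dtype clash than the one in bc.py: true division of two integer tensors promotes to float64 in TensorFlow, which then cannot be mixed with the float32 one-hot greedy probabilities. A minimal sketch:

import tensorflow as tf

legals = tf.constant([1, 1, 0, 1], dtype=tf.int32)  # per-agent legal-action mask
explore = legals / tf.reduce_sum(legals)            # true division promotes int32 to float64
greedy = tf.one_hot(1, 4)                           # float32
# 0.9 * greedy + 0.1 * explore   # InvalidArgumentError: float32 vs float64
probs = 0.9 * greedy + 0.1 * tf.cast(explore, tf.float32)  # mirrors the fix in the diff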
148 changes: 148 additions & 0 deletions vault_conversion.ipynb
@@ -0,0 +1,148 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading vault found at /Users/callum/og-marl/vaults/rware/tiny-4ag.vlt/Replay\n"
]
}
],
"source": [
"from flashbax.vault import Vault\n",
"vlt = Vault(\n",
" vault_name=\"rware/tiny-4ag.vlt\",\n",
" vault_uid=\"Replay\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data = vlt.read()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['action', 'done', 'legal_action_mask', 'observation', 'reward'])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.experience.keys()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"def convert_timestep(input_dict):\n",
" output_dict = {\n",
" \"observations\": jnp.asarray(input_dict[\"observation\"][..., 4:], dtype=jnp.float32),\n",
" \"rewards\": input_dict[\"reward\"],\n",
" \"actions\": jnp.asarray(input_dict[\"action\"], dtype=jnp.int32),\n",
" \"terminals\": input_dict[\"done\"],\n",
" \"truncations\": jnp.zeros_like(input_dict[\"done\"], dtype=jnp.bool_),\n",
" \"infos\": {\n",
" \"legals\": jnp.asarray(input_dict[\"legal_action_mask\"], dtype=jnp.int32),\n",
" \"env_state\": input_dict[\"observation\"][..., 4:].reshape(*input_dict[\"observation\"].shape[:2], -1)\n",
" }\n",
" }\n",
" return output_dict"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New vault created at /Users/callum/og-marl/vaults/rware/tiny-4ag-converted.vlt/Replay\n",
"Since the provided buffer state has a temporal dimension of 6256, you must write to the vault at least every 6255 timesteps to avoid data loss.\n"
]
}
],
"source": [
"new_vault = Vault(\n",
" vault_name=\"rware/tiny-4ag-converted.vlt\",\n",
" vault_uid=\"Replay\",\n",
" experience_structure=convert_timestep(data.experience),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6256"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from flashbax.buffers.trajectory_buffer import TrajectoryBufferState\n",
"import jax\n",
"new_vault.write(\n",
" TrajectoryBufferState(\n",
" experience=convert_timestep(data.experience),\n",
" current_index=jax.tree_util.tree_flatten(data.experience)[0][0].shape[1],\n",
" is_full=False,\n",
" )\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ogmarl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
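
To sanity-check the conversion, the new vault can be reopened with the same API used in the first cell. A sketch, with key names as produced by convert_timestep above:

from flashbax.vault import Vault

converted = Vault(vault_name="rware/tiny-4ag-converted.vlt", vault_uid="Replay")
restored = converted.read()
print(restored.experience.keys())
# expected: observations, rewards, actions, terminals, truncations, infos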
