Commit

Merge main.

jcformanek committed Mar 8, 2024
2 parents 4e15468 + bded235 commit e811629

Showing 7 changed files with 420 additions and 56 deletions.
28 changes: 28 additions & 0 deletions examples/tf2/online/idrqn_smax.py
@@ -0,0 +1,28 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from og_marl.environments.jaxmarl_smax import SMAX
from og_marl.loggers import WandbLogger
from og_marl.replay_buffers import FlashbaxReplayBuffer
from og_marl.tf2.systems.qmix import QMIXSystem

# Create the JaxMARL SMAX environment on the "3m" scenario.
env = SMAX("3m")

# Log training metrics to Weights & Biases.
logger = WandbLogger(entity="claude_formanek")

# QMIX system with epsilon decayed over the first 50k environment steps.
system = QMIXSystem(env, logger, eps_decay_timesteps=50_000)

# Flashbax replay buffer storing sequences of 20 timesteps.
replay_buffer = FlashbaxReplayBuffer(sequence_length=20)

# Train the system online.
system.train_online(replay_buffer)
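The wrappers added in this commit share one interface, so the script above can be retargeted with a one-line change. A minimal sketch (not part of the diff), reusing the logger and hyperparameters from the script above and the Jumanji LBF wrapper added later in this commit:

# Hypothetical variation of the script above: train the same QMIX system
# on the Jumanji LBF wrapper added in this commit. The scenario name is
# one of the task_configs keys in og_marl/environments/jumanji_lbf.py.
from og_marl.environments.jumanji_lbf import JumanjiLBF

env = JumanjiLBF("8x8-2p-2f-coop")
system = QMIXSystem(env, logger, eps_decay_timesteps=50_000)
system.train_online(FlashbaxReplayBuffer(sequence_length=20))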
59 changes: 59 additions & 0 deletions og_marl/environments/__init__.py
@@ -0,0 +1,59 @@
# type: ignore

# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from og_marl.environments.base import BaseEnvironment


def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
    """Return a wrapped environment for the given suite name and scenario.

    Imports are deferred so that only the dependencies of the chosen
    suite need to be installed.
    """
    if env_name == "smac_v1":
        from og_marl.environments.smacv1 import SMACv1

        return SMACv1(scenario)
    elif env_name == "smac_v2":
        from og_marl.environments.smacv2 import SMACv2

        return SMACv2(scenario)
    elif env_name == "mamujoco":
        from og_marl.environments.old_mamujoco import MAMuJoCo

        return MAMuJoCo(scenario)
    elif env_name == "gymnasium_mamujoco":
        from og_marl.environments.gymnasium_mamujoco import MAMuJoCo

        return MAMuJoCo(scenario)
    elif env_name == "flatland":
        from og_marl.environments.flatland_wrapper import Flatland

        return Flatland(scenario)
    elif env_name == "voltage_control":
        from og_marl.environments.voltage_control import VoltageControlEnv

        return VoltageControlEnv()
    elif env_name == "smax":
        from og_marl.environments.jaxmarl_smax import SMAX

        return SMAX(scenario)
    elif env_name == "lbf":
        from og_marl.environments.jumanji_lbf import JumanjiLBF

        return JumanjiLBF(scenario)
    elif env_name == "rware":
        from og_marl.environments.jumanji_rware import JumanjiRware

        return JumanjiRware(scenario)
    else:
        raise ValueError(f"Environment not recognised: {env_name}")
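A quick usage sketch (not part of the diff); the env_name/scenario pair is one of the combinations registered above:

# Sketch: build a wrapper through the factory above and reset it.
# "smax"/"3m" is one of the env_name/scenario pairs handled above.
from og_marl.environments import get_environment

env = get_environment("smax", "3m")
observations, info = env.reset()  # info carries "legals" (and "state" for SMAX)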
98 changes: 98 additions & 0 deletions og_marl/environments/jaxmarl_smax.py
@@ -0,0 +1,98 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base wrapper for Jumanji environments."""
from typing import Any, Dict

import jax
import numpy as np
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn


class SMAX(BaseEnvironment):

"""Environment wrapper for Jumanji environments."""

def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None:
"""Constructor."""
scenario = map_name_to_scenario(scenario_name)

self._environment = make(
"HeuristicEnemySMAX",
enemy_shoots=True,
scenario=scenario,
use_self_play_reward=False,
walls_cause_death=True,
see_enemy_actions=False,
)

self._num_agents = self._environment.num_agents
self.possible_agents = self._environment.agents
self._num_actions = int(self._environment.action_spaces[self.possible_agents[0]].n)

self._state = ... # Jaxmarl environment state

self.info_spec: Dict[str, Any] = {} # TODO add global state spec

self._key = jax.random.PRNGKey(seed)

self._env_step = jax.jit(self._environment.step)

    def reset(self) -> ResetReturn:
        """Resets the env."""
        # Reset the environment
        self._key, sub_key = jax.random.split(self._key)
        obs, self._state = self._environment.reset(sub_key)

        observations = {
            agent: np.asarray(obs[agent], dtype=np.float32) for agent in self.possible_agents
        }
        legals = {
            agent: np.asarray(legal, dtype=np.int64)
            for agent, legal in self._environment.get_avail_actions(self._state).items()
        }
        state = np.asarray(obs["world_state"], dtype=np.float32)

        # Infos
        info = {"legals": legals, "state": state}

        return observations, info

    def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
        """Steps in env."""
        self._key, sub_key = jax.random.split(self._key)

        # Step the environment, using the jitted step created in the
        # constructor rather than the raw environment step.
        obs, self._state, reward, done, infos = self._env_step(
            sub_key, self._state, actions
        )

        observations = {
            agent: np.asarray(obs[agent], dtype=np.float32) for agent in self.possible_agents
        }
        legals = {
            agent: np.asarray(legal, dtype=np.int64)
            for agent, legal in self._environment.get_avail_actions(self._state).items()
        }
        state = np.asarray(obs["world_state"], dtype=np.float32)

        # Infos
        info = {"legals": legals, "state": state}

        rewards = {agent: reward[agent] for agent in self.possible_agents}
        terminals = {agent: done["__all__"] for agent in self.possible_agents}
        truncations = {agent: False for agent in self.possible_agents}

        return observations, rewards, terminals, truncations, info
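To sanity-check the wrapper, a short random-rollout sketch (not part of the diff) that samples only legal actions from the "legals" mask returned in info:

# Sketch: roll out one episode with a uniform-random legal policy.
# Uses only the reset/step interface defined by the wrapper above.
import numpy as np

from og_marl.environments.jaxmarl_smax import SMAX

env = SMAX("3m")
observations, info = env.reset()
terminals = {agent: False for agent in env.possible_agents}
while not all(terminals.values()):
    # Sample a random action among those allowed by the legal-action mask.
    actions = {
        agent: np.random.choice(np.flatnonzero(info["legals"][agent]))
        for agent in env.possible_agents
    }
    observations, rewards, terminals, truncations, info = env.step(actions)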
111 changes: 111 additions & 0 deletions og_marl/environments/jumanji_lbf.py
@@ -0,0 +1,111 @@
# Copyright 2023 InstaDeep Ltd. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base wrapper for Jumanji LBF."""
from typing import Any, Dict

import jax
import jax.numpy as jnp
import jumanji
import numpy as np
from jumanji.environments.routing.lbf.generator import RandomGenerator

from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn

task_configs = {
    "8x8-2p-2f-coop": {
        "grid_size": 8,  # size of the grid to generate.
        "fov": 8,  # field of view of an agent.
        "num_agents": 2,  # number of agents on the grid.
        "num_food": 2,  # number of food in the environment.
        "max_agent_level": 2,  # maximum level of the agents (inclusive).
        "force_coop": True,  # force cooperation between agents.
    },
    "15x15-4p-5f": {
        "grid_size": 15,  # size of the grid to generate.
        "fov": 15,  # field of view of an agent.
        "num_agents": 4,  # number of agents on the grid.
        "num_food": 5,  # number of food in the environment.
        "max_agent_level": 2,  # maximum level of the agents (inclusive).
        "force_coop": False,  # force cooperation between agents.
    },
}


class JumanjiLBF(BaseEnvironment):
    """Environment wrapper for Jumanji Level-Based Foraging."""

    def __init__(self, scenario_name: str = "8x8-2p-2f-coop", seed: int = 0) -> None:
        """Constructor. The scenario name must be a key of task_configs."""
        self._environment = jumanji.make(
            "LevelBasedForaging-v0",
            time_limit=100,
            generator=RandomGenerator(**task_configs[scenario_name]),
        )
        self._num_agents = self._environment.num_agents
        self._num_actions = int(self._environment.action_spec().num_values[0])
        self.possible_agents = [f"agent_{i}" for i in range(self._num_agents)]
        self._state = ...  # Jumanji environment state

        self.info_spec: Dict[str, Any] = {}  # TODO add global state spec

        self._key = jax.random.PRNGKey(seed)

        self._env_step = jax.jit(self._environment.step)

    def reset(self) -> ResetReturn:
        """Resets the env."""
        # Reset the environment
        self._key, sub_key = jax.random.split(self._key)
        self._state, timestep = self._environment.reset(sub_key)

        observations = {
            agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32)
            for i, agent in enumerate(self.possible_agents)
        }
        legals = {
            agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32)
            for i, agent in enumerate(self.possible_agents)
        }

        # Infos
        info = {"legals": legals}

        return observations, info

    def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
        """Steps in env."""
        # Pack per-agent actions into one array without shadowing `actions`.
        action_array = jnp.array([actions[agent] for agent in self.possible_agents])
        # Step the (jitted) environment
        self._state, timestep = self._env_step(self._state, action_array)

        observations = {
            agent: np.asarray(timestep.observation.agents_view[i], dtype=np.float32)
            for i, agent in enumerate(self.possible_agents)
        }
        legals = {
            agent: np.asarray(timestep.observation.action_mask[i], dtype=np.int32)
            for i, agent in enumerate(self.possible_agents)
        }
        rewards = {agent: np.asarray(timestep.reward) for agent in self.possible_agents}
        terminals = {agent: np.asarray(timestep.last()) for agent in self.possible_agents}
        truncations = {agent: np.asarray(False) for agent in self.possible_agents}

        # # Global state # TODO
        # env_state = self._create_state_representation(observations)

        # Extra infos
        info = {"legals": legals}

        return observations, rewards, terminals, truncations, info
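A brief sketch (not part of the diff) showing how the two registered scenarios are constructed through task_configs and what the wrapper exposes after a reset:

# Sketch: instantiate each registered LBF scenario and inspect the
# agent handles and legal-action mask shape after a reset.
from og_marl.environments.jumanji_lbf import JumanjiLBF

for name in ("8x8-2p-2f-coop", "15x15-4p-5f"):
    env = JumanjiLBF(name)
    observations, info = env.reset()
    print(name, env.possible_agents, info["legals"][env.possible_agents[0]].shape)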