
chore: mkdocs for all envs, with a few todos.
callumtilbury committed Mar 11, 2024
1 parent 1fabc48 commit 6ad688f
Showing 15 changed files with 480 additions and 53 deletions.
12 changes: 12 additions & 0 deletions og_marl/environments/__init__.py
@@ -19,6 +19,18 @@


def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
"""Gets the environment object, given the environment name and scenario.
Args:
env_name (str): name of environment (e.g. smac_v1)
scenario (str): scenario name (e.g. 3m)
Raises:
ValueError: Unrecognised environment.
Returns:
BaseEnvironment: Environment object.
"""
if env_name == "smac_v1":
from og_marl.environments.smacv1 import SMACv1

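For orientation, a minimal usage sketch of the dispatcher above (assuming the SMAC v1 dependencies are installed; "3m" is the scenario from the docstring example):

from og_marl.environments import get_environment

# Dispatches on env_name and returns the matching BaseEnvironment subclass.
env = get_environment(env_name="smac_v1", scenario="3m")
observations, info = env.reset()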
32 changes: 23 additions & 9 deletions og_marl/environments/base.py
@@ -38,32 +38,46 @@ def __init__(self) -> None:
pass

def reset(self) -> ResetReturn:
"""Resets the environment.
Raises:
NotImplementedError: Abstract class.
Returns:
ResetReturn: the initial observations and info.
"""
raise NotImplementedError

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps the environment.
Args:
actions (Dict[str, np.ndarray]): Actions taken by the agents.
Raises:
NotImplementedError: Abstract class.
Returns:
StepReturn: the next observations, rewards, terminals, truncations, and info.
"""
raise NotImplementedError

def get_stats(self) -> Dict:
"""Return extra stats to be logged.
"""Returns any extra stats to be logged.
Returns:
-------
extra stats to be logged.
Dict: extra stats to be logged.
"""
return {}

def __getattr__(self, name: str) -> Any:
"""Expose any other attributes of the underlying environment.
Args:
----
name (str): attribute.
name (str): Name of the attribute.
Returns:
-------
Any: return attribute from env or underlying env.
Any: The attribute.
"""
if hasattr(self.__class__, name):
return self.__getattribute__(name)
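To make the abstract contract concrete, here is a minimal sketch of a subclass (DummyEnv is hypothetical and not part of this commit):

from typing import Dict

import numpy as np

from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn


class DummyEnv(BaseEnvironment):
    """Hypothetical two-agent environment implementing the abstract API."""

    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "agent_1"]

    def reset(self) -> ResetReturn:
        # Zero observations for every agent, plus an empty info dict.
        observations = {agent: np.zeros(4, "float32") for agent in self.possible_agents}
        return observations, {}

    def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
        observations = {agent: np.zeros(4, "float32") for agent in self.possible_agents}
        rewards = {agent: np.float32(0.0) for agent in self.possible_agents}
        terminals = {agent: np.bool_(False) for agent in self.possible_agents}
        truncations = {agent: np.bool_(False) for agent in self.possible_agents}
        return observations, rewards, terminals, truncations, {}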
39 changes: 39 additions & 0 deletions og_marl/environments/flatland_wrapper.py
@@ -45,7 +45,14 @@


class Flatland(BaseEnvironment):
"""Wrapper for Flatland."""

def __init__(self, map_name: str = "5_trains"):
"""Constructor.
Args:
map_name (str, optional): name of scenario in Flatland. Defaults to "5_trains".
"""
map_config = FLATLAND_MAP_CONFIGS[map_name]

self._num_actions = 5
@@ -91,6 +98,11 @@ def __init__(self, map_name: str = "5_trains"):
self.max_episode_length = map_config["max_episode_len"]

def reset(self) -> ResetReturn:
"""Resets the environment.
Returns:
ResetReturn: the initial observations and info.
"""
self._done = False

observations, info = self._environment.reset()
@@ -106,6 +118,14 @@ def reset(self) -> ResetReturn:
return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps the environment.
Args:
actions (Dict[str, np.ndarray]): Actions taken by the agents.
Returns:
StepReturn: the next observations, rewards, terminals, truncations, and info.
"""
actions = {int(agent): action.item() for agent, action in actions.items()}

# Step the Flatland environment
@@ -137,6 +157,11 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
return next_observations, rewards, terminals, truncations, info

def _get_legal_actions(self) -> Dict[str, np.ndarray]:
"""Computes the legal actions for each agent.
Returns:
Dict[str, np.ndarray]: legal actions for each agent.
"""
legal_actions = {}
for agent in self.possible_agents:
agent_id = int(agent)
@@ -155,6 +180,11 @@ def _get_legal_actions(self) -> Dict[str, np.ndarray]:
return legal_actions

def _make_state_representation(self) -> np.ndarray:
"""Creates the state representation.
Returns:
np.ndarray: state representation.
"""
state = []
for i, _ in enumerate(self.possible_agents):
agent = self._environment.agents[i]
@@ -179,6 +209,15 @@ def _convert_observations(
observations: Dict[int, np.ndarray],
info: Dict[str, Dict[int, np.ndarray]],
) -> Observations:
"""TODO
Args:
observations (Dict[int, np.ndarray]):
info (Dict[str, Dict[int, np.ndarray]]):
Returns:
Observations: _description_
"""
new_observations = {}
for i, agent in enumerate(self.possible_agents):
agent_id = i
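The legal-action mask produced by _get_legal_actions can be used to restrict a policy to valid moves. A small sketch (sample_legal_action is a hypothetical helper; it assumes the mask is a 0/1 vector over Flatland's 5 actions):

import numpy as np

def sample_legal_action(legal_mask: np.ndarray, rng: np.random.Generator) -> int:
    """Samples uniformly among the actions whose mask entry is non-zero."""
    legal_indices = np.flatnonzero(legal_mask)
    return int(rng.choice(legal_indices))

rng = np.random.default_rng(42)
sample_legal_action(np.array([1, 0, 1, 1, 0]), rng)  # returns 0, 2, or 3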
36 changes: 34 additions & 2 deletions og_marl/environments/gymnasium_mamujoco.py
@@ -21,7 +21,14 @@


def get_env_config(scenario: str) -> Dict[str, Any]:
"""Helper method to get env_args."""
"""Gets the environment configuration, given the scenario.
Args:
scenario (str): scenario name.
Returns:
Dict[str, Any]: environment configuration, comprising the scenario and agent configuration.
"""
env_args: Dict[str, Any] = {
"agent_obsk": 1,
}
@@ -54,27 +61,52 @@ class MAMuJoCo:
"""Environment wrapper Multi-Agent MuJoCo."""

def __init__(self, scenario: str):
"""Constructor.
Args:
scenario (str): scenario name.
"""
env_config = get_env_config(scenario)
self._environment = gymnasium_robotics.mamujoco_v0.parallel_env(**env_config)

self.info_spec = {"state": self._environment.state()}

def reset(self) -> ResetReturn:
"""Resets the environment.
Returns:
ResetReturn: the initial observations and info.
"""
observations, _ = self._environment.reset()

info = {"state": self._environment.state().astype("float32")}

return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps the environment.
Args:
actions (Dict[str, np.ndarray]): Actions taken by the agents.
Returns:
StepReturn: the next observations, rewards, terminals, truncations, and info.
"""
observations, rewards, terminals, truncations, _ = self._environment.step(actions)

info = {"state": self._environment.state().astype("float32")}

return observations, rewards, terminals, truncations, info

def __getattr__(self, name: str) -> Any:
"""Expose any other attributes of the underlying environment."""
"""Exposes attributes of the underlying environment.
Args:
name (str): name of the attribute.
Returns:
Any: the attribute.
"""
if hasattr(self.__class__, name):
return self.__getattribute__(name)
else:
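A usage sketch for the wrapper above (the scenario string "2x3" is illustrative only; valid names are whatever get_env_config accepts):

from og_marl.environments.gymnasium_mamujoco import MAMuJoCo

env = MAMuJoCo("2x3")  # hypothetical scenario name
observations, info = env.reset()
state = info["state"]  # the global state is exposed through info at reset and step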
25 changes: 21 additions & 4 deletions og_marl/environments/jaxmarl_smax.py
@@ -18,15 +18,21 @@
import numpy as np
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn


class SMAX(BaseEnvironment):

"""Environment wrapper for Jumanji environments."""
"""Environment wrapper for SMAX environments from JaxMARL."""

def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None:
"""Constructor."""
"""Constructor.
Args:
scenario_name (str, optional): name of scenario in SMAX. Defaults to "3m".
seed (int, optional): random seed initialisation. Defaults to 0.
"""
scenario = map_name_to_scenario(scenario_name)

self._environment = make(
@@ -51,7 +57,11 @@ def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None:
self._env_step = jax.jit(self._environment.step)

def reset(self) -> ResetReturn:
"""Resets the env."""
"""Resets the environment.
Returns:
ResetReturn: the initial observations and info.
"""
# Reset the environment
self._key, sub_key = jax.random.split(self._key)
obs, self._state = self._environment.reset(sub_key)
@@ -71,7 +81,14 @@ def reset(self) -> ResetReturn:
return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps in env."""
"""Steps the environment.
Args:
actions (Dict[str, np.ndarray]): Actions taken by the agents.
Returns:
StepReturn: the next observations, rewards, terminals, truncations, and info.
"""
self._key, sub_key = jax.random.split(self._key)

# Step the environment
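An interaction-loop sketch for the SMAX wrapper (defaults from the constructor; the random policy and the action-space size of 5 are placeholders, not values from this commit):

import numpy as np

from og_marl.environments.jaxmarl_smax import SMAX

env = SMAX(scenario_name="3m", seed=0)
observations, info = env.reset()
rng = np.random.default_rng(0)
# Hypothetical random policy over a placeholder discrete action space.
actions = {agent: np.array(rng.integers(5)) for agent in observations}
observations, rewards, terminals, truncations, info = env.step(actions)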
26 changes: 21 additions & 5 deletions og_marl/environments/jumanji_lbf.py
@@ -44,10 +44,15 @@

class JumanjiLBF(BaseEnvironment):

"""Environment wrapper for Jumanji environments."""
"""Environment wrapper for Jumanji's Level-based Foraging environment."""

def __init__(self, scenario_name: str = "2s-8x8-2p-2f-coop", seed: int = 0) -> None:
"""Constructor."""
"""Constructor.
Args:
scenario_name (str, optional): name of scenario in LBF. Defaults to "2s-8x8-2p-2f-coop".
seed (int, optional): random seed initialisation. Defaults to 0.
"""
self._environment = jumanji.make(
"LevelBasedForaging-v0",
time_limit=100,
@@ -65,8 +70,12 @@ def __init__(self, scenario_name: str = "2s-8x8-2p-2f-coop", seed: int = 0) -> None:
self._env_step = jax.jit(self._environment.step)

def reset(self) -> ResetReturn:
"""Resets the env."""
# Reset the environment
"""Resets the environment.
Returns:
ResetReturn: the initial observations and info.
"""
# Resets the underlying environment
self._key, sub_key = jax.random.split(self._key)
self._state, timestep = self._environment.reset(sub_key)

@@ -85,7 +94,14 @@ def reset(self) -> ResetReturn:
return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps in env."""
"""Steps the environment.
Args:
actions (Dict[str, np.ndarray]): Actions taken by the agents.
Returns:
StepReturn: the next observations, rewards, terminals, truncations, and info.
"""
actions = jnp.array([actions[agent] for agent in self.possible_agents])
# Step the environment
self._state, timestep = self._env_step(self._state, actions)
25 changes: 21 additions & 4 deletions og_marl/environments/jumanji_rware.py
@@ -44,10 +44,16 @@

class JumanjiRware(BaseEnvironment):

"""Environment wrapper for Jumanji environments."""
"""Environment wrapper for Jumanji's Robot-Warehouse environments."""

def __init__(self, scenario_name: str = "tiny-4ag", seed: int = 0) -> None:
"""Constructor."""
"""Constructor.
Args:
scenario_name (str, optional):
name of scenario in Robot-Warehouse. Defaults to "tiny-4ag".
seed (int, optional): random seed initialisation. Defaults to 0.
"""
self._environment = jumanji.make(
"RobotWarehouse-v0",
generator=RandomGenerator(**task_configs[scenario_name]),
@@ -64,7 +70,11 @@ def __init__(self, scenario_name: str = "tiny-4ag", seed: int = 0) -> None:
self._env_step = jax.jit(self._environment.step, donate_argnums=0)

def reset(self) -> ResetReturn:
"""Resets the env."""
"""Resets the environment.
Returns:
ResetReturn: the initial observations and info.
"""
# Reset the environment
self._key, sub_key = jax.random.split(self._key)
self._state, timestep = self._environment.reset(sub_key)
@@ -84,7 +94,14 @@ def reset(self) -> ResetReturn:
return observations, info

def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
"""Steps in env."""
"""Steps the environment.
Args:
actions (Dict[str, np.ndarray]): Actions taken by the agents.
Returns:
StepReturn: the next observations, rewards, terminals, truncations, and info.
"""
actions = jnp.array(list(actions.values())) # .squeeze(-1)
# Step the environment
self._state, timestep = self._env_step(self._state, actions)
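One design note on the jitted step above: donate_argnums=0 donates the first argument (the environment state) to XLA, which may then reuse its device buffers for the new state instead of allocating a copy. A standalone sketch of the semantics (toy function, not from this repo):

import jax
import jax.numpy as jnp

# Without donation, the old state's buffers stay alive after the call.
step_copy = jax.jit(lambda state, action: state + action)

# With donation, XLA may overwrite the old state's buffers in place,
# so the donated argument must not be read again after the call.
step_donated = jax.jit(lambda state, action: state + action, donate_argnums=0)

state = jnp.ones(3)
new_state = step_donated(state, jnp.ones(3))  # `state` is now invalid to reuse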