From 6ad688f9f97a0711d2e8af4080f96ac36de0e7a0 Mon Sep 17 00:00:00 2001
From: Callum Tilbury
Date: Mon, 11 Mar 2024 15:24:12 +0200
Subject: [PATCH] chore: mkdocs for all envs, with a few todos.

---
 og_marl/environments/__init__.py           |  12 +++
 og_marl/environments/base.py               |  32 +++++--
 og_marl/environments/flatland_wrapper.py   |  39 ++++++++
 og_marl/environments/gymnasium_mamujoco.py |  36 +++++++-
 og_marl/environments/jaxmarl_smax.py       |  25 ++++-
 og_marl/environments/jumanji_lbf.py        |  26 +++++-
 og_marl/environments/jumanji_rware.py      |  25 ++++-
 og_marl/environments/old_mamujoco.py       |  38 +++++++-
 og_marl/environments/pettingzoo_base.py    |  50 +++++++++-
 og_marl/environments/pistonball.py         |  39 ++++++--
 og_marl/environments/pursuit.py            |  20 +++-
 og_marl/environments/smacv1.py             |  26 +++++-
 og_marl/environments/smacv2.py             |  28 +++++-
 og_marl/environments/voltage_control.py    |  35 ++++++-
 og_marl/environments/wrappers.py           | 102 ++++++++++++++++++++-
 15 files changed, 480 insertions(+), 53 deletions(-)

diff --git a/og_marl/environments/__init__.py b/og_marl/environments/__init__.py
index c1439ff2..f13ff78c 100644
--- a/og_marl/environments/__init__.py
+++ b/og_marl/environments/__init__.py
@@ -19,6 +19,18 @@


 def get_environment(env_name: str, scenario: str) -> BaseEnvironment:
+    """Gets the environment object, given the environment name and scenario.
+
+    Args:
+        env_name (str): name of environment (e.g. smac_v1)
+        scenario (str): scenario name (e.g. 3m)
+
+    Raises:
+        ValueError: Unrecognised environment.
+
+    Returns:
+        BaseEnvironment: Environment object.
+    """
     if env_name == "smac_v1":
         from og_marl.environments.smacv1 import SMACv1

diff --git a/og_marl/environments/base.py b/og_marl/environments/base.py
index 2789eee8..fc1469a4 100644
--- a/og_marl/environments/base.py
+++ b/og_marl/environments/base.py
@@ -38,18 +38,35 @@
     def __init__(self) -> None:
         pass

     def reset(self) -> ResetReturn:
+        """Resets the environment.
+
+        Raises:
+            NotImplementedError: Abstract class.
+
+        Returns:
+            ResetReturn: the initial observations and info.
+        """
         raise NotImplementedError

     def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
+        """Steps the environment.
+
+        Args:
+            actions (Dict[str, np.ndarray]): Actions taken by the agents.
+
+        Raises:
+            NotImplementedError: Abstract class.
+
+        Returns:
+            StepReturn: the next observations, rewards, terminals, truncations, and info.
+        """
         raise NotImplementedError

     def get_stats(self) -> Dict:
-        """Return extra stats to be logged.
+        """Returns any extra stats to be logged.

         Returns:
-        -------
-        extra stats to be logged.
-
+            Dict: extra stats to be logged.
         """
         return {}

@@ -57,13 +74,10 @@
     def __getattr__(self, name: str) -> Any:
         """Expose any other attributes of the underlying environment.

         Args:
-        ----
-        name (str): attribute.
+            name (str): Name of the attribute.

         Returns:
-        -------
-        Any: return attribute from env or underlying env.
-
+            Any: The attribute.
         """
         if hasattr(self.__class__, name):
             return self.__getattribute__(name)

diff --git a/og_marl/environments/flatland_wrapper.py b/og_marl/environments/flatland_wrapper.py
index f6e5dc06..8cb51f83 100644
--- a/og_marl/environments/flatland_wrapper.py
+++ b/og_marl/environments/flatland_wrapper.py
@@ -45,7 +45,14 @@

 class Flatland(BaseEnvironment):
+    """Wrapper for Flatland."""
+
     def __init__(self, map_name: str = "5_trains"):
+        """Constructor.
+
+        Args:
+            map_name (str, optional): name of scenario in Flatland. Defaults to "5_trains".
+ """ map_config = FLATLAND_MAP_CONFIGS[map_name] self._num_actions = 5 @@ -91,6 +98,11 @@ def __init__(self, map_name: str = "5_trains"): self.max_episode_length = map_config["max_episode_len"] def reset(self) -> ResetReturn: + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ self._done = False observations, info = self._environment.reset() @@ -106,6 +118,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ actions = {int(agent): action.item() for agent, action in actions.items()} # Step the Flatland environment @@ -137,6 +157,11 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return next_observations, rewards, terminals, truncations, info def _get_legal_actions(self) -> Dict[str, np.ndarray]: + """Computes the legal actions for each agent. + + Returns: + Dict[str, np.ndarray]: legal actions for each agent. + """ legal_actions = {} for agent in self.possible_agents: agent_id = int(agent) @@ -155,6 +180,11 @@ def _get_legal_actions(self) -> Dict[str, np.ndarray]: return legal_actions def _make_state_representation(self) -> np.ndarray: + """Creates the state representation. + + Returns: + np.ndarray: state representation. + """ state = [] for i, _ in enumerate(self.possible_agents): agent = self._environment.agents[i] @@ -179,6 +209,15 @@ def _convert_observations( observations: Dict[int, np.ndarray], info: Dict[str, Dict[int, np.ndarray]], ) -> Observations: + """TODO + + Args: + observations (Dict[int, np.ndarray]): + info (Dict[str, Dict[int, np.ndarray]]): + + Returns: + Observations: _description_ + """ new_observations = {} for i, agent in enumerate(self.possible_agents): agent_id = i diff --git a/og_marl/environments/gymnasium_mamujoco.py b/og_marl/environments/gymnasium_mamujoco.py index 7e1aa687..7f8b211d 100644 --- a/og_marl/environments/gymnasium_mamujoco.py +++ b/og_marl/environments/gymnasium_mamujoco.py @@ -21,7 +21,14 @@ def get_env_config(scenario: str) -> Dict[str, Any]: - """Helper method to get env_args.""" + """Gets the environment configuration, given the scenario. + + Args: + scenario (str): scenario name. + + Returns: + Dict[str, Any]: environment configuration, comprising the scenario and agent configuration. + """ env_args: Dict[str, Any] = { "agent_obsk": 1, } @@ -54,12 +61,22 @@ class MAMuJoCo: """Environment wrapper Multi-Agent MuJoCo.""" def __init__(self, scenario: str): + """Constructor. + + Args: + scenario (str): scenario name. + """ env_config = get_env_config(scenario) self._environment = gymnasium_robotics.mamujoco_v0.parallel_env(**env_config) self.info_spec = {"state": self._environment.state()} def reset(self) -> ResetReturn: + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ observations, _ = self._environment.reset() info = {"state": self._environment.state().astype("float32")} @@ -67,6 +84,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. 
+ """ observations, rewards, terminals, trunctations, _ = self._environment.step(actions) info = {"state": self._environment.state().astype("float32")} @@ -74,7 +99,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return observations, rewards, terminals, trunctations, info def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" + """Exposes attributes of the underlying environment. + + Args: + name (str): name of the attribute. + + Returns: + Any: the attribute. + """ if hasattr(self.__class__, name): return self.__getattribute__(name) else: diff --git a/og_marl/environments/jaxmarl_smax.py b/og_marl/environments/jaxmarl_smax.py index 29f05fd6..ddeddf5b 100644 --- a/og_marl/environments/jaxmarl_smax.py +++ b/og_marl/environments/jaxmarl_smax.py @@ -18,15 +18,21 @@ import numpy as np from jaxmarl import make from jaxmarl.environments.smax import map_name_to_scenario + from og_marl.environments.base import BaseEnvironment, ResetReturn, StepReturn class SMAX(BaseEnvironment): - """Environment wrapper for Jumanji environments.""" + """Environment wrapper for SMAX environments from JaxMARL.""" def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None: - """Constructor.""" + """Constructor. + + Args: + scenario_name (str, optional): name of scenario in SMAX. Defaults to "3m". + seed (int, optional): random seed initialisation. Defaults to 0. + """ scenario = map_name_to_scenario(scenario_name) self._environment = make( @@ -51,7 +57,11 @@ def __init__(self, scenario_name: str = "3m", seed: int = 0) -> None: self._env_step = jax.jit(self._environment.step) def reset(self) -> ResetReturn: - """Resets the env.""" + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ # Reset the environment self._key, sub_key = jax.random.split(self._key) obs, self._state = self._environment.reset(sub_key) @@ -71,7 +81,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Steps in env.""" + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ self._key, sub_key = jax.random.split(self._key) # Step the environment diff --git a/og_marl/environments/jumanji_lbf.py b/og_marl/environments/jumanji_lbf.py index f74da913..97d94b97 100644 --- a/og_marl/environments/jumanji_lbf.py +++ b/og_marl/environments/jumanji_lbf.py @@ -44,10 +44,15 @@ class JumanjiLBF(BaseEnvironment): - """Environment wrapper for Jumanji environments.""" + """Environment wrapper for Jumanji's Level-based Foraging environment.""" def __init__(self, scenario_name: str = "2s-8x8-2p-2f-coop", seed: int = 0) -> None: - """Constructor.""" + """Constructor. + + Args: + scenario_name (str, optional): name of scenario in LBF. Defaults to "2s-8x8-2p-2f-coop". + seed (int, optional): random seed initialisation. Defaults to 0. + """ self._environment = jumanji.make( "LevelBasedForaging-v0", time_limit=100, @@ -65,8 +70,12 @@ def __init__(self, scenario_name: str = "2s-8x8-2p-2f-coop", seed: int = 0) -> N self._env_step = jax.jit(self._environment.step) def reset(self) -> ResetReturn: - """Resets the env.""" - # Reset the environment + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. 
+ """ + # Resets the underlying environment self._key, sub_key = jax.random.split(self._key) self._state, timestep = self._environment.reset(sub_key) @@ -85,7 +94,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Steps in env.""" + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ actions = jnp.array([actions[agent] for agent in self.possible_agents]) # Step the environment self._state, timestep = self._env_step(self._state, actions) diff --git a/og_marl/environments/jumanji_rware.py b/og_marl/environments/jumanji_rware.py index 1998af88..6184876a 100644 --- a/og_marl/environments/jumanji_rware.py +++ b/og_marl/environments/jumanji_rware.py @@ -44,10 +44,16 @@ class JumanjiRware(BaseEnvironment): - """Environment wrapper for Jumanji environments.""" + """Environment wrapper for Jumanji's Robot-Warehouse environments.""" def __init__(self, scenario_name: str = "tiny-4ag", seed: int = 0) -> None: - """Constructor.""" + """Constructor. + + Args: + scenario_name (str, optional): + name of scenario in Robot-Warehouse. Defaults to "tiny-4ag". + seed (int, optional): random seed initialisation. Defaults to 0. + """ self._environment = jumanji.make( "RobotWarehouse-v0", generator=RandomGenerator(**task_configs[scenario_name]), @@ -64,7 +70,11 @@ def __init__(self, scenario_name: str = "tiny-4ag", seed: int = 0) -> None: self._env_step = jax.jit(self._environment.step, donate_argnums=0) def reset(self) -> ResetReturn: - """Resets the env.""" + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ # Reset the environment self._key, sub_key = jax.random.split(self._key) self._state, timestep = self._environment.reset(sub_key) @@ -84,7 +94,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Steps in env.""" + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ actions = jnp.array(list(actions.values())) # .squeeze(-1) # Step the environment self._state, timestep = self._env_step(self._state, actions) diff --git a/og_marl/environments/old_mamujoco.py b/og_marl/environments/old_mamujoco.py index c546f89e..ce88a77b 100644 --- a/og_marl/environments/old_mamujoco.py +++ b/og_marl/environments/old_mamujoco.py @@ -21,6 +21,17 @@ def get_mamujoco_args(scenario: str) -> Dict[str, Any]: + """Gets the environment configuration, given the scenario. + + Args: + scenario (str): scenario name. + + Raises: + ValueError: Not a valid mamujoco scenario. + + Returns: + Dict[str, Any]: environment configuration, comprising the scenario and agent configuration. + """ env_args = { "agent_obsk": 1, "episode_limit": 1000, @@ -45,6 +56,11 @@ class MAMuJoCo(BaseEnvironment): """Environment wrapper Multi-Agent MuJoCo.""" def __init__(self, scenario: str): + """Constructor. + + Args: + scenario (str): scenario name. + """ env_args = get_mamujoco_args(scenario) self._environment = MujocoMulti(env_args=env_args) @@ -71,6 +87,11 @@ def __init__(self, scenario: str): self.max_episode_length = 1000 def reset(self) -> ResetReturn: + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. 
+ """ self._environment.reset() observations = self._environment.get_obs() @@ -84,6 +105,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ mujoco_actions = [] for agent in self.possible_agents: mujoco_actions.append(actions[agent]) @@ -106,7 +135,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return observations, rewards, terminals, trunctations, info # type: ignore def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" + """Exposes attributes of the underlying environment. + + Args: + name (str): name of the attribute. + + Returns: + Any: the attribute. + """ if hasattr(self.__class__, name): return self.__getattribute__(name) else: diff --git a/og_marl/environments/pettingzoo_base.py b/og_marl/environments/pettingzoo_base.py index edbfd13d..9b646d3a 100644 --- a/og_marl/environments/pettingzoo_base.py +++ b/og_marl/environments/pettingzoo_base.py @@ -28,8 +28,12 @@ def __init__(self) -> None: self.info_spec: Dict[str, Any] = {} def reset(self) -> ResetReturn: - """Resets the env.""" - # Reset the environment + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ + # Reset the underlying environment observations = self._environment.reset() # type: ignore # Global state @@ -44,7 +48,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Steps in env.""" + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations and info. + """ # Step the environment observations, rewards, terminals, truncations, _ = self._environment.step(actions) @@ -57,6 +68,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return observations, rewards, terminals, truncations, info def _add_zero_obs_for_missing_agent(self, observations: Observations) -> Observations: + """TODO + + Args: + observations (Observations): _description_ + + Returns: + Observations: _description_ + """ for agent in self._agents: if agent not in observations: observations[agent] = np.zeros( @@ -68,9 +87,30 @@ def _add_zero_obs_for_missing_agent(self, observations: Observations) -> Observa def _convert_observations( self, observations: Dict[str, np.ndarray], done: bool ) -> Dict[str, np.ndarray]: - """Convert observations""" + """TODO + + Args: + observations (Dict[str, np.ndarray]): _description_ + done (bool): _description_ + + Raises: + NotImplementedError: _description_ + + Returns: + Dict[str, np.ndarray]: _description_ + """ raise NotImplementedError def _create_state_representation(self, observations: Dict[str, np.ndarray]) -> np.ndarray: - """Create global state representation from agent observations.""" + """Create global state representation from agent observations. + + Args: + observations (Dict[str, np.ndarray]): Observations from the agents. + + Raises: + NotImplementedError: Abstract class. + + Returns: + np.ndarray: Global state representation. 
+ """ raise NotImplementedError diff --git a/og_marl/environments/pistonball.py b/og_marl/environments/pistonball.py index bc628c45..6da8c7cb 100644 --- a/og_marl/environments/pistonball.py +++ b/og_marl/environments/pistonball.py @@ -31,6 +31,11 @@ class Pistonball(PettingZooBase): """Environment wrapper for PettingZoo MARL environments.""" def __init__(self, n_pistons: int = 15): + """Constructor. + + Args: + n_pistons (int, optional): number of pistons. Defaults to 15. + """ self._environment = pistonball_v6.parallel_env( n_pistons=n_pistons, continuous=True, render_mode="rgb_array" ) @@ -50,6 +55,14 @@ def __init__(self, n_pistons: int = 15): self._done = False def _create_state_representation(self, observations: Dict[str, np.ndarray]) -> np.ndarray: + """Create state representation from observations. + + Args: + observations (Dict[str, np.ndarray]): observations from the environment. + + Returns: + np.ndarray: state representation. + """ if self._step_type == dm_env.StepType.FIRST: self._state_history = np.zeros((56, 88, 4), "float32") @@ -68,6 +81,15 @@ def _create_state_representation(self, observations: Dict[str, np.ndarray]) -> n def _convert_observations( self, observations: Dict[str, np.ndarray], done: bool ) -> Dict[str, np.ndarray]: + """TODO OLT deprecated? + + Args: + observations (Dict[str, np.ndarray]): _description_ + done (bool): _description_ + + Returns: + Dict[str, np.ndarray]: _description_ + """ olt_observations = {} for _, agent in enumerate(self._agents): agent_obs = np.expand_dims(observations[agent][50:, :], axis=-1) @@ -82,18 +104,21 @@ def _convert_observations( return olt_observations # type: ignore def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Function returns extra spec (format) of the env. + """Gets the extra spec of the env. Returns: - ------- - Dict[str, specs.BoundedArray]: extra spec. - + Dict[str, specs.BoundedArray]: extra spec of the env. """ state_spec = {"s_t": np.zeros((56, 88, 4), "float32")} # four stacked frames return state_spec def action_spec(self) -> Dict[str, specs.BoundedArray]: + """Gets the action spec of the env. + + Returns: + Dict[str, specs.BoundedArray]: action spec of the env for each agent. + """ action_spec = {} for agent in self._agents: spec = specs.BoundedArray( @@ -105,12 +130,10 @@ def action_spec(self) -> Dict[str, specs.BoundedArray]: return action_spec def observation_spec(self) -> Dict[str, OLT]: - """Observation spec. + """Gets the observation spec of the env. Returns: - ------- - types.Observation: spec for environment. - + Dict[str, OLT]: Observation spec. TODO OLT """ observation_specs = {} for agent in self._agents: diff --git a/og_marl/environments/pursuit.py b/og_marl/environments/pursuit.py index e6bb028e..aef0ca98 100644 --- a/og_marl/environments/pursuit.py +++ b/og_marl/environments/pursuit.py @@ -27,7 +27,7 @@ class Pursuit(PettingZooBase): """Environment wrapper for Pursuit.""" def __init__(self) -> None: - """Constructor for Pursuit""" + """Constructor.""" self._environment = black_death_v3(pursuit_v4.parallel_env()) self.possible_agents = self._environment.possible_agents self._num_actions = 5 @@ -41,10 +41,26 @@ def __init__(self) -> None: self.info_spec = {"state": np.zeros(8 * 2 + 30 * 2, "float32")} def _convert_observations(self, observations: Observations, done: bool) -> Observations: - """Convert observations.""" + """Convert observations to OLT format. 
TODO + + Args: + observations (Observations): _description_ + done (bool): _description_ + + Returns: + Observations: _description_ + """ return observations def _create_state_representation(self, observations: Observations) -> np.ndarray: + """Create state representation from observations. + + Args: + observations (Observations): Observations from the environment. + + Returns: + np.ndarray: State representation. + """ pursuer_pos = [ agent.current_position() for agent in self._environment.aec_env.env.env.env.pursuers ] diff --git a/og_marl/environments/smacv1.py b/og_marl/environments/smacv1.py index 46cb37c6..86867c27 100644 --- a/og_marl/environments/smacv1.py +++ b/og_marl/environments/smacv1.py @@ -30,6 +30,11 @@ def __init__( self, map_name: str, ): + """Constructor. + + Args: + map_name (str): _description_ + """ self._environment = StarCraft2Env(map_name=map_name) self.possible_agents = [f"agent_{n}" for n in range(self._environment.n_agents)] @@ -52,7 +57,11 @@ def __init__( self.max_episode_length = self._environment.episode_limit def reset(self) -> ResetReturn: - """Resets the env.""" + """Reset the environment. + + Returns: + ResetReturn: the initial observations and info. + """ # Reset the environment self._environment.reset() self._done = False @@ -71,7 +80,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Step in env.""" + """Step the environment. + + Args: + actions (Dict[str, np.ndarray]): _description_ + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ # Convert dict of actions to list for SMAC smac_actions = [] for agent in self.possible_agents: @@ -100,7 +116,11 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return observations, rewards, terminals, truncations, info def _get_legal_actions(self) -> List[np.ndarray]: - """Get legal actions from the environment.""" + """Get the legal actions for each agent. + + Returns: + List[np.ndarray]: the legal actions for each agent. + """ legal_actions = [] for i, _ in enumerate(self.possible_agents): legal_actions.append( diff --git a/og_marl/environments/smacv2.py b/og_marl/environments/smacv2.py index c9d3038f..5b66ada7 100644 --- a/og_marl/environments/smacv2.py +++ b/og_marl/environments/smacv2.py @@ -90,6 +90,11 @@ class SMACv2(BaseEnvironment): """Environment wrapper SMAC.""" def __init__(self, scenario: str): + """Constructor. + + Args: + scenario (str): name of the SMACv2 scenario. + """ self._environment = StarCraftCapabilityEnvWrapper( capability_config=DISTRIBUTION_CONFIGS[scenario], map_name=MAP_NAMES[scenario], @@ -121,8 +126,12 @@ def __init__(self, scenario: str): self.max_episode_length = self._environment.episode_limit def reset(self) -> ResetReturn: - """Resets the env.""" - # Reset the environment + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ + # Reset the underlying environment self._environment.reset() self._done = False @@ -140,7 +149,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Step in env.""" + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. 
+ """ # Convert dict of actions to list for SMAC smac_actions = [] for agent in self.possible_agents: @@ -169,7 +185,11 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return observations, rewards, terminals, truncations, info def _get_legal_actions(self) -> List[np.ndarray]: - """Get legal actions from the environment.""" + """Gets the legal actions for each agent. + + Returns: + List[np.ndarray]: the legal actions for each agent. + """ legal_actions = [] for i, _ in enumerate(self.possible_agents): legal_actions.append( diff --git a/og_marl/environments/voltage_control.py b/og_marl/environments/voltage_control.py index b1450ac2..234f1c50 100644 --- a/og_marl/environments/voltage_control.py +++ b/og_marl/environments/voltage_control.py @@ -13,7 +13,7 @@ class VoltageControlEnv(BaseEnvironment): """Environment wrapper for MAPDN environment.""" def __init__(self) -> None: - """Constructor for VoltageControl.""" + """Constructor.""" self._environment = VoltageControl() self.possible_agents = [ f"agent_{agent_id}" for agent_id in range(self._environment.get_num_of_agents()) @@ -35,7 +35,12 @@ def __init__(self) -> None: } def reset(self) -> ResetReturn: - # Reset the environment + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ + # Reset the underlying environment observations, state = self._environment.reset() # Global state @@ -52,7 +57,14 @@ def reset(self) -> ResetReturn: return observations, info def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: - """Steps in env.""" + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations and info. + """ actions = self._preprocess_actions(actions) # Step the environment @@ -83,6 +95,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return next_observations, rewards, terminals, truncations, info def _preprocess_actions(self, actions: Dict[str, np.ndarray]) -> np.ndarray: + """TODO + + Args: + actions (Dict[str, np.ndarray]): _description_ + + Returns: + np.ndarray: _description_ + """ concat_action = [] for agent in self.possible_agents: concat_action.append(actions[agent]) @@ -90,6 +110,15 @@ def _preprocess_actions(self, actions: Dict[str, np.ndarray]) -> np.ndarray: return concat_action # type: ignore def _convert_observations(self, observations: List, done: bool) -> Observations: + """Converts the observations to a dictionary format.. + + Args: + observations (List): _description_ + done (bool): _description_ + + Returns: + Observations: _description_ + """ dict_observations = {} for i, agent in enumerate(self.possible_agents): obs = np.array(observations[i], "float32") diff --git a/og_marl/environments/wrappers.py b/og_marl/environments/wrappers.py index d832bf83..2cc67d7d 100644 --- a/og_marl/environments/wrappers.py +++ b/og_marl/environments/wrappers.py @@ -29,6 +29,14 @@ class ExperienceRecorder: def __init__( self, environment: BaseEnvironment, vault_name: str, write_to_vault_every: int = 10_000 ): + """Constructor for the ExperienceRecorder. + + Args: + environment (BaseEnvironment): environment that is being wrapped. + vault_name (str): name of the vault to write to. + write_to_vault_every (int, optional): + how often to write to the vault. Defaults to 10_000. 
+ """ self._environment = environment self._buffer = fbx.make_flat_buffer( max_length=2 * 10_000, @@ -53,6 +61,19 @@ def _pack_timestep( truncations: Dict[str, np.ndarray], infos: Dict[str, Any], ) -> Dict[str, Any]: + """Packa an incoming timestep into a dictionary. + + Args: + observations (Dict[str, np.ndarray]): Observations from the environment. + actions (Dict[str, np.ndarray]): Actions taken by the agents. + rewards (Dict[str, np.ndarray]): Rewards received by the agents. + terminals (Dict[str, np.ndarray]): Whether the agents have terminated. + truncations (Dict[str, np.ndarray]): Whether the agents have been truncated. TODO + infos (Dict[str, Any]): Extra info from the environment. + + Returns: + Dict[str, Any]: _description_ + """ packed_timestep = { "observations": observations, "actions": actions, @@ -65,6 +86,11 @@ def _pack_timestep( return packed_timestep def reset(self) -> ResetReturn: + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ observations, infos = self._environment.reset() self._observations = observations @@ -73,6 +99,14 @@ def reset(self) -> ResetReturn: return observations, infos def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ observations, rewards, terminals, truncations, infos = self._environment.step(actions) packed_timestep = self._pack_timestep( @@ -113,7 +147,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return observations, rewards, terminals, truncations, infos def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" + """Expose any other attributes of the underlying environment. + + Args: + name (str): name of the attribute. + + Returns: + Any: the attribute. + """ if hasattr(self.__class__, name): return self.__getattribute__(name) else: @@ -121,11 +162,24 @@ def __getattr__(self, name: str) -> Any: class Dtype: + """TODO""" + def __init__(self, environment: BaseEnvironment, dtype: str): + """Constructor for the Dtype wrapper. + + Args: + environment (BaseEnvironment): environment that is being wrapped. + dtype (str): data type to cast the observations to. + """ self._environment = environment self._dtype = dtype def reset(self) -> ResetReturn: + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ observations = self._environment.reset() if isinstance(observations, tuple): @@ -139,6 +193,14 @@ def reset(self) -> ResetReturn: return observations, infos def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ next_observations, rewards, terminals, truncations, infos = self._environment.step(actions) for agent, observation in next_observations.items(): @@ -147,7 +209,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return next_observations, rewards, terminals, truncations, infos def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" + """Expose any other attributes of the underlying environment. + + Args: + name (str): name of the attribute. + + Returns: + Any: the attribute. 
+ """ if hasattr(self.__class__, name): return self.__getattribute__(name) else: @@ -155,7 +224,14 @@ def __getattr__(self, name: str) -> Any: class PadObsandActs: + """TODO""" + def __init__(self, environment: BaseEnvironment): + """Constructor for the PadObsandActs wrapper. + + Args: + environment (BaseEnvironment): environment that is being wrapped. + """ self._environment = environment self._obs_dim = 0 @@ -172,6 +248,11 @@ def __init__(self, environment: BaseEnvironment): self._obs_dim = obs_dim def reset(self) -> ResetReturn: + """Resets the environment. + + Returns: + ResetReturn: the initial observations and info. + """ observations = self._environment.reset() if isinstance(observations, tuple): @@ -189,6 +270,14 @@ def reset(self) -> ResetReturn: return observations, infos def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: + """Steps the environment. + + Args: + actions (Dict[str, np.ndarray]): Actions taken by the agents. + + Returns: + StepReturn: the next observations, rewards, terminals, truncations, and info. + """ actions = { agent: action[: self._environment.action_spaces[agent].shape[0]] for agent, action in actions.items() @@ -205,7 +294,14 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn: return next_observations, rewards, terminals, truncations, infos def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" + """Expose any other attributes of the underlying environment. + + Args: + name (str): name of the attribute. + + Returns: + Any: the attribute. + """ if hasattr(self.__class__, name): return self.__getattribute__(name) else: