diff --git a/python/griddly/__init__.py b/python/griddly/__init__.py index 802eb6bb..23ec86ec 100644 --- a/python/griddly/__init__.py +++ b/python/griddly/__init__.py @@ -1,4 +1,5 @@ import os + import yaml from griddly.gym import GymWrapperFactory diff --git a/python/griddly/gym.py b/python/griddly/gym.py index cccef3f3..2ba3b12c 100644 --- a/python/griddly/gym.py +++ b/python/griddly/gym.py @@ -8,8 +8,8 @@ from gymnasium.envs.registration import register from gymnasium.spaces import Discrete, MultiDiscrete -from griddly.loader import GriddlyLoader from griddly import gd as gd +from griddly.loader import GriddlyLoader from griddly.spaces.action_space import MultiAgentActionSpace from griddly.spaces.observation_space import ( EntityObservationSpace, @@ -363,7 +363,7 @@ def enable_history(self, enable: bool = True) -> None: self._enable_history = enable self.game.enable_history(enable) - def step( # type: ignore + def step( # type: ignore self, action: Union[Action, List[Action]] ) -> Tuple[ Union[List[Observation], Observation], @@ -380,28 +380,50 @@ def step( # type: ignore ragged_actions = [] max_num_actions = 1 + try: + if self.player_count == 1: + ragged_actions.append( + np.array(action, dtype=np.int32).reshape( + -1, len(self.action_space_parts) + ) + ) + max_num_actions = ragged_actions[0].shape[0] + else: + for p in range(self.player_count): + a: Union[Action, List[Action]] + if isinstance(action, list): + if action[p] is None: + a = np.zeros(len(self.action_space_parts), dtype=np.int32) + else: + a = action[p] + else: + a = action + + ragged_actions.append( + np.array(a, dtype=np.int32).reshape( + -1, len(self.action_space_parts) + ) + ) - if self.player_count == 1: - ragged_actions.append(np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts))) - max_num_actions = ragged_actions[0].shape[0] - else: - for p in range(self.player_count): - if isinstance(action, list): - ragged_actions.append(np.array(action[p], dtype=np.int32).reshape(-1, len(self.action_space_parts))) - else: - ragged_actions.append(np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts))) - - if ragged_actions[p].shape[0] > max_num_actions: - max_num_actions = ragged_actions[p].shape[0] - - action_data = np.zeros((self.player_count, max_num_actions, len(self.action_space_parts)), dtype=np.int32) - - for p in range(self.player_count): - for i, a in enumerate(ragged_actions[p]): - action_data[p, i] = a + if ragged_actions[p].shape[0] > max_num_actions: + max_num_actions = ragged_actions[p].shape[0] + action_data = np.zeros( + (self.player_count, max_num_actions, len(self.action_space_parts)), + dtype=np.int32, + ) - reward, done, truncated, info = self.game.step_parallel(action_data) + for p in range(self.player_count): + for i, a in enumerate(ragged_actions[p]): + action_data[p, i] = a + + reward, done, truncated, info = self.game.step_parallel(action_data) + except Exception as e: + raise ValueError( + f"Invalid action {action} for action space {self.action_space}." 
\
+                f" Example valid action: {self.action_space.sample()}",
+                e
+            )
 
         # Compatibility with gymnasium
         if self.player_count == 1:
diff --git a/python/griddly/loader.py b/python/griddly/loader.py
index c89585e0..98e7562b 100644
--- a/python/griddly/loader.py
+++ b/python/griddly/loader.py
@@ -5,6 +5,7 @@
 
 from griddly import gd
 
+
 class GriddlyLoader:
     def __init__(self) -> None:
         module_path = os.path.dirname(os.path.realpath(__file__))
@@ -43,4 +44,4 @@ def load_string(self, yaml_string: str) -> gd.GDY:
 
     def load_gdy(self, gdy_path: str) -> Dict[str, Any]:
         with open(self.get_full_path(gdy_path)) as gdy_file:
-            return yaml.load(gdy_file, Loader=yaml.SafeLoader)  # type: ignore
\ No newline at end of file
+            return yaml.load(gdy_file, Loader=yaml.SafeLoader)  # type: ignore
diff --git a/python/griddly/spaces/action_space.py b/python/griddly/spaces/action_space.py
index 65b4dd58..87829612 100644
--- a/python/griddly/spaces/action_space.py
+++ b/python/griddly/spaces/action_space.py
@@ -3,9 +3,9 @@
 from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 import numpy as np
-from gymnasium.spaces import Discrete, MultiDiscrete, Space
+from gymnasium.spaces import Space
 
-from griddly.typing import Action, ActionSpace
+from griddly.typing import Action
 
 if TYPE_CHECKING:
     from griddly.gym import GymWrapper
diff --git a/python/griddly/util/breakdown.py b/python/griddly/util/breakdown.py
index 38c5ab28..36d1cbd5 100644
--- a/python/griddly/util/breakdown.py
+++ b/python/griddly/util/breakdown.py
@@ -6,14 +6,15 @@
 import numpy.typing as npt
 import yaml
 
-from griddly.loader import GriddlyLoader
 from griddly import gd
+from griddly.loader import GriddlyLoader
 from griddly.util.vector_visualization import Vector2RGB
 
 
 class TemporaryEnvironment:
     """
-    Because we have to load the game many different times with different configurations, this class makes sure we clean up objects we dont need
+    Because we have to load the game many different times with different configurations,
+    this class makes sure we clean up objects we don't need
     """
 
     def __init__(
@@ -120,7 +121,7 @@ def _populate_common_properties(self) -> None:
             for observer_name, config in object["Observers"].items():
                 self.supported_observers.add(observer_name)
 
-        self.observer_configs: Dict[str, Dict] = { 
+        self.observer_configs: Dict[str, Dict] = {
            "Block2D": {},
            "Sprite2D": {},
            "Vector": {},
@@ -207,12 +208,12 @@ def _populate_levels(self) -> None:
 
                 if self.observer_configs[observer_name]["TrackAvatar"]:
                     continue
-                for l, level in self.levels.items():
-                    env.game.load_level(l)
+                for l_key, level in self.levels.items():
+                    env.game.load_level(l_key)
                     env.game.reset()
                     rendered_level = env.render_rgb()
-                    self.levels[l]["Observers"][observer_name] = rendered_level
-                    self.levels[l]["Size"] = [
+                    self.levels[l_key]["Observers"][observer_name] = rendered_level
+                    self.levels[l_key]["Size"] = [
                         env.game.get_width(),
                         env.game.get_height(),
                     ]
diff --git a/python/griddly/util/environment_generator_generator.py b/python/griddly/util/environment_generator_generator.py
index 08e5f1d8..a9376595 100644
--- a/python/griddly/util/environment_generator_generator.py
+++ b/python/griddly/util/environment_generator_generator.py
@@ -1,10 +1,12 @@
 import os
-from typing import Optional, List, Any, Dict, Union
+from typing import Any, Dict, List, Optional, Union
+
 import gymnasium as gym
 import numpy as np
 import yaml
+
 from griddly import gd
-from griddly.gym import GymWrapperFactory, GymWrapper
+from griddly.gym import GymWrapper, GymWrapperFactory
 
 
 class EnvironmentGeneratorGenerator:
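The reworked GymWrapper.step() above packs whatever the caller passes into a single (player_count, max_num_actions, len(action_space_parts)) int32 array before calling game.step_parallel(), zero-filling players that supply fewer sub-actions or None. A minimal standalone sketch of that padding logic follows; the helper name pad_ragged_actions is hypothetical and not part of the diff.

import numpy as np


def pad_ragged_actions(action, player_count, action_space_parts):
    # Mirrors the reshaping in step(): each player's action is coerced to
    # shape (-1, parts) and then zero-padded to a common sub-action count.
    parts = len(action_space_parts)
    ragged = []
    for p in range(player_count):
        a = action[p] if (player_count > 1 and isinstance(action, list)) else action
        if a is None:
            a = np.zeros(parts, dtype=np.int32)
        ragged.append(np.array(a, dtype=np.int32).reshape(-1, parts))
    max_num_actions = max(r.shape[0] for r in ragged)
    action_data = np.zeros((player_count, max_num_actions, parts), dtype=np.int32)
    for p, r in enumerate(ragged):
        action_data[p, : r.shape[0]] = r
    return action_data


# Two players, one sub-action each, two action components per sub-action:
assert pad_ragged_actions([[0, 1], [2, 3]], 2, [5, 5]).shape == (2, 1, 2)
# A single player can still pass a flat action:
assert pad_ragged_actions([3, 4], 1, [5, 5]).shape == (1, 1, 2)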
diff --git a/python/griddly/util/rllib/__init__.py b/python/griddly/util/rllib/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/griddly/util/rllib/callbacks.py b/python/griddly/util/rllib/callbacks.py deleted file mode 100644 index 9901166d..00000000 --- a/python/griddly/util/rllib/callbacks.py +++ /dev/null @@ -1,144 +0,0 @@ -from collections import Counter -from typing import Any, Dict, List, Optional, Union - -from ray.rllib import Policy, RolloutWorker -from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.env import BaseEnv -from ray.rllib.evaluation import MultiAgentEpisode -from ray.rllib.utils.typing import PolicyID -from wandb import Video - - -class GriddlyRLLibCallbacks(DefaultCallbacks): - """Contains helper functions for Griddly callbacks""" - - def _get_player_ids(self, base_env: BaseEnv, env_index: int) -> List[Union[str, int]]: - envs = base_env.get_sub_environments() - player_count = envs[env_index].player_count - if player_count == 1: - return ["agent0"] - else: - return [p for p in range(1, player_count + 1)] - - -class VideoCallbacks(GriddlyRLLibCallbacks): - def on_episode_start( - self, - *, - worker: RolloutWorker, - base_env: BaseEnv, - policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs: Dict[str, Any], - ) -> None: - envs = base_env.get_sub_environments() - envs[env_index].on_episode_start(worker.worker_index, env_index) - - def on_episode_end( - self, - *, - worker: RolloutWorker, - base_env: BaseEnv, - policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs: Dict[str, Any], - ) -> None: - envs = base_env.get_sub_environments() - - for video in envs[env_index].videos: - level = video["level"] - path = video["path"] - episode.media[f"level_{level}"] = Video(path) - - envs[env_index].videos = [] - - -class ActionTrackerCallbacks(GriddlyRLLibCallbacks): - def __init__(self) -> None: - super().__init__() - - self._action_frequency_trackers: Dict[int, List[Counter]] = {} - - def on_episode_start( - self, - *, - worker: RolloutWorker, - base_env: BaseEnv, - policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs: Dict[str, Any] - ) -> None: - self._action_frequency_trackers[episode.episode_id] = [] - assert env_index is not None, "Env index must be set" - for _ in self._get_player_ids(base_env, env_index): - self._action_frequency_trackers[episode.episode_id].append(Counter()) - - def on_episode_step( - self, - *, - worker: RolloutWorker, - base_env: BaseEnv, - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs: Dict[str, Any] - ) -> None: - assert env_index is not None, "Env index must be set" - for p, id in enumerate(self._get_player_ids(base_env, env_index)): - info = episode.last_info_for(id) - if "History" in info: - history = info["History"] - for event in history: - action_name = event["ActionName"] - self._action_frequency_trackers[episode.episode_id][p][ - action_name - ] += 1 - - def on_episode_end( - self, - *, - worker: RolloutWorker, - base_env: BaseEnv, - policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs: Dict[str, Any] - ) -> None: - assert env_index is not None, "Env index must be set" - for p, id in enumerate(self._get_player_ids(base_env, env_index)): - for action_name, frequency in self._action_frequency_trackers[ - episode.episode_id - ][p].items(): - 
episode.custom_metrics[f"agent_info/{id}/{action_name}"] = frequency - - del self._action_frequency_trackers[episode.episode_id] - - -class WinLoseMetricCallbacks(GriddlyRLLibCallbacks): - def __init__(self) -> None: - super().__init__() - - def on_episode_end( - self, - *, - worker: RolloutWorker, - base_env: BaseEnv, - policies: Dict[PolicyID, Policy], - episode: MultiAgentEpisode, - env_index: Optional[int] = None, - **kwargs: Dict[str, Any] - ) -> None: - assert env_index is not None, "Env index must be set" - for p, id in enumerate(self._get_player_ids(base_env, env_index)): - info = episode.last_info_for(id) - episode.custom_metrics[f"agent_info/{id}/win"] = ( - 1 if info["PlayerResults"][f"{p + 1}"] == "Win" else 0 - ) - episode.custom_metrics[f"agent_info/{id}/lose"] = ( - 1 if info["PlayerResults"][f"{p + 1}"] == "Lose" else 0 - ) - episode.custom_metrics[f"agent_info/{id}/end"] = ( - 1 if info["PlayerResults"][f"{p + 1}"] == "End" else 0 - ) diff --git a/python/griddly/util/rllib/environment/__init__.py b/python/griddly/util/rllib/environment/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/griddly/util/rllib/environment/base.py b/python/griddly/util/rllib/environment/base.py deleted file mode 100644 index cf7fd2a2..00000000 --- a/python/griddly/util/rllib/environment/base.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -from abc import ABC -from typing import Any, Dict, List, Optional, Union - -from griddly.gym import GymWrapper -from griddly.spaces.action_space import MultiAgentActionSpace -from griddly.spaces.observation_space import MultiAgentObservationSpace -from griddly.typing import ActionSpace, ObservationSpace -from griddly.util.rllib.environment.observer_episode_recorder import ( - ObserverEpisodeRecorder, -) - - -class _RLlibEnvCache: - def __init__(self) -> None: - self.reset() - - def reset(self) -> None: - self.action_space: Optional[Union[ActionSpace, MultiAgentActionSpace]] = None - self.observation_space: Optional[ - Union[ObservationSpace, MultiAgentObservationSpace] - ] = None - - -class _RLlibEnv(ABC): - def __init__(self, env_config: Dict[str, Any]) -> None: - self._rllib_cache = _RLlibEnvCache() - - self._env = GymWrapper(**env_config, reset=False) - - self.env_config = env_config - - self.env_steps = 0 - self._agent_recorders: Optional[ - Union[ObserverEpisodeRecorder, List[ObserverEpisodeRecorder]] - ] = None - self._global_recorder: Optional[ObserverEpisodeRecorder] = None - - self._env_idx: Optional[int] = None - self._worker_idx: Optional[int] = None - - self.video_initialized = False - - self.record_video_config = env_config.get("record_video_config", None) - - self.videos: List[Dict[str, Any]] = [] - - if self.record_video_config is not None: - self.video_frequency = self.record_video_config.get("frequency", 1000) - self.fps = self.record_video_config.get("fps", 10) - self.video_directory = os.path.realpath( - self.record_video_config.get("directory", ".") - ) - self.include_global_video = self.record_video_config.get( - "include_global", True - ) - self.include_agent_videos = self.record_video_config.get( - "include_agents", False - ) - os.makedirs(self.video_directory, exist_ok=True) - - self.record_actions = env_config.get("record_actions", False) - - self.generate_valid_action_trees = env_config.get( - "generate_valid_action_trees", False - ) - self._random_level_on_reset = env_config.get("random_level_on_reset", False) - level_generator_rllib_config = env_config.get("level_generator", None) - - self._level_generator 
= None - if level_generator_rllib_config is not None: - level_generator_class = level_generator_rllib_config["class"] - level_generator_config = level_generator_rllib_config["config"] - self._level_generator = level_generator_class(level_generator_config) - - self._env.enable_history(self.record_actions) - - @property - def width(self) -> int: - assert self._env.observation_space.shape is not None - return self._env.observation_space.shape[0] - - @property - def height(self) -> int: - assert self._env.observation_space.shape is not None - return self._env.observation_space.shape[1] - - def _get_valid_action_trees(self) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - valid_action_trees = self._env.game.build_valid_action_trees() - if self._env.player_count == 1: - return valid_action_trees[0] - return valid_action_trees diff --git a/python/griddly/util/rllib/environment/level_generator.py b/python/griddly/util/rllib/environment/level_generator.py deleted file mode 100644 index 10072d1f..00000000 --- a/python/griddly/util/rllib/environment/level_generator.py +++ /dev/null @@ -1,6 +0,0 @@ -class LevelGenerator: - def __init__(self, config: dict) -> None: - self._config = config - - def generate(self) -> str: - raise NotImplementedError() diff --git a/python/griddly/util/rllib/environment/multi_agent.py b/python/griddly/util/rllib/environment/multi_agent.py deleted file mode 100644 index 22c4b027..00000000 --- a/python/griddly/util/rllib/environment/multi_agent.py +++ /dev/null @@ -1,174 +0,0 @@ -from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple - -from ray.rllib.utils.typing import MultiAgentDict - -from griddly.typing import Action -from griddly.util.rllib.environment.base import _RLlibEnv -from griddly.util.rllib.environment.observer_episode_recorder import \ - ObserverEpisodeRecorder - - -class RLlibMultiAgentWrapper(_RLlibEnv): - def __init__(self, env_config: Dict[str, Any]) -> None: - super().__init__(env_config) - - self.reset() - - self._player_done_variable = env_config.get("player_done_variable", None) - - # Used to keep track of agents that are active in the environment - self._active_agents: Set[int] = set() - - assert ( - self._env.player_count > 1 - ), "RLlibMultiAgentWrapper can only be used with environments that have multiple agents" - - def _to_multi_agent_map(self, data: List[Any]) -> MultiAgentDict: - return {a: data[a - 1] for a in self._active_agents} - - def reset( - self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> Tuple[Dict[int, Any], Dict[Any, Any]]: - obs, info = self._env.reset(seed, options) - assert isinstance(obs, list), "RLlibMultiAgentWrapper expects a list of obs" - self._agent_ids = [a for a in range(1, self._env.player_count + 1)] - self._active_agents.update([a for a in range(1, self._env.player_count + 1)]) - return self._to_multi_agent_map(obs), info - - def _resolve_player_done_variable(self) -> MultiAgentDict: - resolved_variables = self._env.game.get_global_variables( - [self._player_done_variable] - ) - is_player_done = resolved_variables[self._player_done_variable] - assert isinstance( - is_player_done, Dict - ), "player_done_variable must be a global variable" - return is_player_done - - def _after_step( - self, - obs_map: MultiAgentDict, - reward_map: MultiAgentDict, - done_map: MultiAgentDict, - truncated_map: MultiAgentDict, - info_map: MultiAgentDict, - ) -> MultiAgentDict: - extra_info: MultiAgentDict = {} - - if self.is_video_enabled(): - videos_list = [] - if 
self.include_agent_videos: - for a in self._active_agents: - assert self._agent_recorders is not None and isinstance( - self._agent_recorders, list - ) - end_video = ( - done_map[a] - or done_map["__all__"] - or truncated_map[a] - or truncated_map["__all__"] - ) - video_info = self._agent_recorders[a].step( - self._env.level_id, self.env_steps, end_video - ) - if video_info is not None: - videos_list.append(video_info) - if self.include_global_video and self._global_recorder is not None: - end_video = done_map["__all__"] or truncated_map["__all__"] - video_info = self._global_recorder.step( - self._env.level_id, self.env_steps, end_video - ) - if video_info is not None: - videos_list.append(video_info) - - self.videos = videos_list - - return extra_info - - def step( - self, action_dict: MultiAgentDict - ) -> Tuple[ - MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict - ]: - actions_array: List[Action] = [] * self._env.player_count - for agent_id, action in action_dict.items(): - actions_array[agent_id - 1] = action - - obs, reward, all_done, all_truncated, info = self._env.step(actions_array) - - done_map: Dict[str, bool] = {"__all__": all_done} - truncated_map: Dict[str, bool] = {"__all__": all_truncated} - - if self._player_done_variable is not None: - griddly_players_done = self._resolve_player_done_variable() - - for agent_id in self._active_agents: - done_map[agent_id] = griddly_players_done[agent_id] == 1 - truncated_map[ - agent_id - ] = False # TODO: not sure how to support multi-agent truncated? - else: - for p in range(1, self._env.player_count + 1): - done_map[str(p)] = False - - if self.generate_valid_action_trees: - info_map = self._to_multi_agent_map( - [ - {"valid_action_tree": valid_action_tree} - for valid_action_tree in info["valid_action_tree"] - ] - ) - else: - info_map = defaultdict(lambda: defaultdict(dict)) - - if self.record_actions: - for event in info["History"]: - event_player_id = event["PlayerId"] - if event_player_id != 0: - if "History" not in info_map[event_player_id]: - info_map[event_player_id]["History"] = [] - info_map[event_player_id]["History"].append(event) - - assert isinstance(obs, list), "RLlibMultiAgentWrapper expects a list of obs" - assert isinstance( - reward, list - ), "RLlibMultiAgentWrapper expects a list of rewards" - obs_map = self._to_multi_agent_map(obs) - reward_map = self._to_multi_agent_map(reward) - - # Finally remove any agent ids that are done - for agent_id, is_done in done_map.items(): - if is_done: - self._active_agents.discard(agent_id) - - self._after_step(obs_map, reward_map, done_map, truncated_map, info_map) - - return obs_map, reward_map, done_map, truncated_map, info_map - - def is_video_enabled(self) -> bool: - return ( - self.record_video_config is not None - and self._env_idx is not None - and self._env_idx == 0 - ) - - def on_episode_start(self, worker_idx: int, env_idx: int) -> None: - self._env_idx = env_idx - self._worker_idx = worker_idx - - if self.is_video_enabled() and not self.video_initialized: - self.init_video_recording() - self.video_initialized = True - - def init_video_recording(self) -> None: - if self.include_agent_videos: - assert isinstance(self._agent_recorders, list) - for agent_id in self._agent_ids: - self._agent_recorders[agent_id] = ObserverEpisodeRecorder( - self._env, agent_id - 1, self.video_frequency, self.video_directory - ) - if self.include_global_video: - self._global_recorder = ObserverEpisodeRecorder( - self._env, "global", self.video_frequency, 
self.video_directory - ) diff --git a/python/griddly/util/rllib/environment/observer_episode_recorder.py b/python/griddly/util/rllib/environment/observer_episode_recorder.py deleted file mode 100644 index 762fc271..00000000 --- a/python/griddly/util/rllib/environment/observer_episode_recorder.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -from enum import Enum -from typing import Any, Dict, Union, Optional -from uuid import uuid1 - -from griddly.gym import GymWrapper -from griddly.util.render_tools import RenderToVideo -from griddly.wrappers.render_wrapper import RenderWrapper - -class RecordingState(Enum): - NOT_RECORDING = 1 - WAITING_FOR_EPISODE_START = 2 - BEFORE_RECORDING = 3 - RECORDING = 4 - - -class ObserverEpisodeRecorder: - def __init__( - self, - env: GymWrapper, - observer: Union[str, int], - video_frequency: int, - video_directory: str = ".", - fps: int = 10, - ) -> None: - self._video_frequency = video_frequency - self._video_directory = video_directory - self._observer = observer - self._env = RenderWrapper(env, observer, "rgb_array") - self._fps = fps - - self._recording_state = RecordingState.BEFORE_RECORDING - self._recorder: RenderToVideo - - def step(self, level_id: int, step_count: int, done: bool) -> Optional[Dict[str, Any]]: - video_info = None - - if ( - self._recording_state is RecordingState.NOT_RECORDING - and step_count % self._video_frequency == 0 - ): - self._recording_state = RecordingState.WAITING_FOR_EPISODE_START - - if self._recording_state == RecordingState.BEFORE_RECORDING: - video_filename = os.path.join( - self._video_directory, - f"episode_video_{self._observer}_{uuid1()}_{level_id}_{step_count}.mp4", - ) - - self._recorder = RenderToVideo(self._env, video_filename) - - self._recording_state = RecordingState.RECORDING - - if self._recording_state == RecordingState.RECORDING: - self._recorder.capture_frame() - if done: - self._recording_state = RecordingState.NOT_RECORDING - self._recorder.close() - - video_info = {"level": level_id, "path": self._recorder.path} - - if self._recording_state == RecordingState.WAITING_FOR_EPISODE_START: - if done: - self._recording_state = RecordingState.BEFORE_RECORDING - - return video_info - - def __del__(self) -> None: - self._recorder.close() diff --git a/python/griddly/util/rllib/environment/single_agent.py b/python/griddly/util/rllib/environment/single_agent.py deleted file mode 100644 index 2a396ea6..00000000 --- a/python/griddly/util/rllib/environment/single_agent.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import numpy.typing as npt - -from griddly.typing import Action, Observation -from griddly.util.rllib.environment.base import _RLlibEnv -from griddly.util.rllib.environment.observer_episode_recorder import \ - ObserverEpisodeRecorder - - -class RLlibEnv(_RLlibEnv): - """ - Wraps a Griddly environment for compatibility with RLLib. 
- - Use the `env_config` in the rllib config to provide Griddly Environment Parameters - - Example: - - Firstly register the RLlibWrapper using rllib's - - env_name = "my_env_name" - - register_env(env_name, RLlibWrapper) - - you can then configure it - - rllib_config = { - 'env_config': { - 'yaml_file': 'Single-Player/GVGAI/butterflies.yaml', - 'level": 6, - 'player_observer_type': gd.ObserverType.SPRITE_2D, - 'global_observer_type': gd.ObserverType.ISOMETRIC, - 'max_steps': 1000, - }, - # Other configuration options - } - - Create the rllib trainer using this config: - - trainer = ImpalaTrainer(rllib_config, env=env_name) - - """ - - def __init__(self, env_config: Dict[str, Any]) -> None: - super().__init__(env_config) - - self.reset() - - def _transform( - self, observation: Union[List[Observation], Observation] - ) -> Union[List[Observation], Observation]: - transformed_obs: Union[List[Observation], Observation] - if self._env.player_count > 1 and isinstance(observation, list): - transformed_obs = [] - for obs in observation: - assert isinstance( - obs, npt.NDArray - ), "When using RLLib, observations must be numpy arrays, such as VECTOR or SPRITE_2D" - transformed_obs.append(obs.transpose(1, 2, 0).astype(float)) - elif isinstance(observation, npt.NDArray): - transformed_obs = observation.transpose(1, 2, 0).astype(float) - else: - raise Exception( - f"Unsupported observation type {type(observation)} for {self.__class__.__name__}" - ) - - return transformed_obs - - def _after_step( - self, - observation: Union[List[Observation], Observation], - reward: Union[List[int], int], - done: bool, - info: Dict[str, Any], - ) -> Dict[str, Any]: - extra_info: Dict[str, Any] = {} - - if self.is_video_enabled(): - videos_list = [] - if self.include_agent_videos: - video_info = self._agent_recorder.step( - self._env.level_id, self.env_steps, done - ) - if video_info is not None: - videos_list.append(video_info) - if self.include_global_video and self._global_recorder is not None: - video_info = self._global_recorder.step( - self._env.level_id, self.env_steps, done - ) - if video_info is not None: - videos_list.append(video_info) - - self.videos = videos_list - - return extra_info - - def reset( - self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> Tuple[Union[List[Observation], Observation], Dict[Any, Any]]: - if options is None: - options = {} - if self._level_generator is not None: - options["level_string"] = self._level_generator.generate() - elif self._random_level_on_reset: - options["level_id"] = np.random.choice(self._env.level_count) - - self._rllib_cache.reset() - observation, info = self._env.reset(seed=seed, options=options) - - if self.generate_valid_action_trees: - self.last_valid_action_trees = self._get_valid_action_trees() - - return self._transform(observation), info - - def step( - self, action: Action - ) -> Tuple[ - Union[List[Observation], Observation], - Union[List[int], int], - bool, - bool, - Dict[Any, Any], - ]: - observation, reward, truncated, done, info = self._env.step(action) - - extra_info = self._after_step(observation, reward, done, info) - info.update(extra_info) - - if self.generate_valid_action_trees: - self.last_valid_action_trees = self._get_valid_action_trees() - info["valid_action_tree"] = self.last_valid_action_trees.copy() - - self.env_steps += 1 - - return self._transform(observation), reward, done, truncated, info - - def render(self) -> Union[str, npt.NDArray]: - return self._env.render() - - def is_video_enabled(self) -> 
bool: - return ( - self.record_video_config is not None - and self._env_idx is not None - and self._env_idx == 0 - ) - - def on_episode_start(self, worker_idx: int, env_idx: int) -> None: - self._env_idx = env_idx - self._worker_idx = worker_idx - - if self.is_video_enabled() and not self.video_initialized: - self.init_video_recording() - self.video_initialized = True - - def init_video_recording(self) -> None: - - if self.include_agent_videos: - self._agent_recorder = ObserverEpisodeRecorder( - self._env, 1, self.video_frequency, self.video_directory, self.fps - ) - if self.include_global_video: - self._global_recorder = ObserverEpisodeRecorder( - self._env, - "global", - self.video_frequency, - self.video_directory, - self.fps, - ) \ No newline at end of file diff --git a/python/griddly/wrappers/render_wrapper.py b/python/griddly/wrappers/render_wrapper.py index 94f814e7..fb855d88 100644 --- a/python/griddly/wrappers/render_wrapper.py +++ b/python/griddly/wrappers/render_wrapper.py @@ -12,7 +12,7 @@ class RenderWrapper(gym.Wrapper): env: GymWrapper - + def __init__( self, env: GymWrapper, observer: Union[str, int] = 0, render_mode: str = "human" ) -> None: diff --git a/python/griddly/wrappers/valid_action_space_wrapper.py b/python/griddly/wrappers/valid_action_space_wrapper.py index 9fd5bd2c..da6960f2 100644 --- a/python/griddly/wrappers/valid_action_space_wrapper.py +++ b/python/griddly/wrappers/valid_action_space_wrapper.py @@ -7,7 +7,7 @@ import numpy.typing as npt from griddly.gym import GymWrapper -from griddly.spaces.action_space import MultiAgentActionSpace, ValidatedActionSpace +from griddly.spaces.action_space import ValidatedActionSpace class ValidActionSpaceWrapper(gym.Wrapper): diff --git a/python/tests/entity_observer_test.py b/python/tests/entity_observer_test.py index 766f3268..94ef4101 100644 --- a/python/tests/entity_observer_test.py +++ b/python/tests/entity_observer_test.py @@ -32,7 +32,7 @@ def test_entity_observations(test_name): player_observer_type=gd.ObserverType.ENTITY, ) - obs, reward, done, truncated, info = env.step(0) + obs, reward, done, truncated, info = env.step([0, 0]) entities = obs["Entities"] entity_ids = obs["Ids"] obs["Locations"] @@ -104,7 +104,7 @@ def test_entity_observations_multi_agent(test_name): assert player_2_space["entity_2"] == ["x", "y", "z"] assert player_2_space["__global__"] == ["test_global_variable"] - obs, reward, done, truncated, info = env.step([0, 0]) + obs, reward, done, truncated, info = env.step([[0, 0, 0, 0], [0, 0, 0, 0]]) player_1_obs = obs[0] diff --git a/python/tests/example_mechanics/conditionals/conditionals.yaml b/python/tests/example_mechanics/conditionals/conditionals.yaml index a8330920..935706e4 100644 --- a/python/tests/example_mechanics/conditionals/conditionals.yaml +++ b/python/tests/example_mechanics/conditionals/conditionals.yaml @@ -5,8 +5,6 @@ Environment: This environment tests * Preconditions * If statements - Player: - AvatarObject: object # The player can only control a single avatar in the game Levels: - | w w w w w w w w w w w w w diff --git a/python/tests/named_observer_test.py b/python/tests/named_observer_test.py index 15859023..c9853516 100644 --- a/python/tests/named_observer_test.py +++ b/python/tests/named_observer_test.py @@ -36,7 +36,7 @@ def test_vector1(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step(0) + obs, reward, done, truncated, info = env.step([0, 0]) assert env.player_observation_space.shape == (1, 10, 10) assert obs.shape == (1, 10, 10) 
@@ -56,7 +56,7 @@ def test_vector2(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step(0) + obs, reward, done, truncated, info = env.step([0, 0]) assert env.player_observation_space.shape == (1, 5, 5) assert obs.shape == (1, 5, 5) @@ -76,7 +76,7 @@ def test_vector3(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step(0) + obs, reward, done, truncated, info = env.step([0, 0]) assert env.player_observation_space.shape == (1, 4, 4) assert obs.shape == (1, 4, 4) @@ -96,7 +96,7 @@ def test_multi_object_vector1(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step([0, 0]) + obs, reward, done, truncated, info = env.step([[0, 0], [0, 0]]) assert env.observation_space[0].shape == (1, 10, 10) assert env.observation_space[1].shape == (1, 10, 10) @@ -121,7 +121,7 @@ def test_multi_object_vector2(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step([0, 0]) + obs, reward, done, truncated, info = env.step([[0, 0], [0, 0]]) assert env.observation_space[0].shape == (1, 5, 5) assert env.observation_space[1].shape == (1, 5, 5) @@ -146,7 +146,7 @@ def test_multi_object_vector3(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step([0, 0]) + obs, reward, done, truncated, info = env.step([[0, 0], [0, 0]]) assert env.observation_space[0].shape == (1, 4, 4) assert env.observation_space[1].shape == (1, 4, 4) @@ -171,7 +171,7 @@ def test_multi_object_vector1_vector2(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step([0, 0]) + obs, reward, done, truncated, info = env.step([[0, 0], [0, 0]]) assert env.observation_space[0].shape == (1, 10, 10) assert env.observation_space[1].shape == (1, 5, 5) @@ -196,7 +196,7 @@ def test_multi_object_vector2_vector3(test_name): init_obs, init_info = env.reset() - obs, reward, done, truncated, info = env.step([0, 0]) + obs, reward, done, truncated, info = env.step([[0, 0], [0, 0]]) assert env.observation_space[0].shape == (1, 5, 5) assert env.observation_space[1].shape == (1, 4, 4) diff --git a/python/tests/partial_observability_test.py b/python/tests/partial_observability_test.py index 343db316..352e7682 100644 --- a/python/tests/partial_observability_test.py +++ b/python/tests/partial_observability_test.py @@ -40,7 +40,7 @@ def test_partial_observability_0_1(test_name): obs, reward, done, truncated, info = env.step([0, 0]) player1_obs = obs[0] - player2_obs = obs[1] + player2_obs = obs[1] assert env.player_observation_space[0].shape == (1, 3, 3) assert env.player_observation_space[1].shape == (1, 3, 3) diff --git a/python/tests/rllib_test.py b/python/tests/rllib_test.py deleted file mode 100644 index 64eb0d33..00000000 --- a/python/tests/rllib_test.py +++ /dev/null @@ -1,375 +0,0 @@ -import os -import shutil -import sys - -import pytest -import ray -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.models import ModelCatalog -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.tune import register_env, tune -from torch import nn - -from griddly import gd -from griddly.util.rllib.callbacks import VideoCallbacks -from griddly.util.rllib.environment.single_agent import RLlibEnv -from griddly.util.rllib.environment.multi_agent import RLlibMultiAgentWrapper - - -def count_videos(video_dir): - count = 0 - for path in os.listdir(video_dir): - # check if current path is a file - if 
os.path.isfile(os.path.join(video_dir, path)): - count += 1 - - return count - - -class SingleAgentFlatModel(TorchModelV2, nn.Module): - def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs): - TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs) - nn.Module.__init__(self) - self.model = nn.Sequential( - nn.Flatten(), - (nn.Linear(468, 128)), - nn.ReLU(), - ) - self.policy_fn = nn.Linear(128, num_outputs) - self.value_fn = nn.Linear(128, 1) - - def forward(self, input_dict, state, seq_lens): - model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2)) - self._value_out = self.value_fn(model_out) - return self.policy_fn(model_out), state - - def value_function(self): - return self._value_out.flatten() - - -@pytest.fixture -def test_name(request): - return request.node.name - - -@pytest.fixture(scope="module", autouse=True) -def ray_init(): - sep = os.pathsep - os.environ["PYTHONPATH"] = sep.join(sys.path) - ray.init(include_dashboard=False, num_cpus=1, num_gpus=0) - - -def test_rllib_single_player(test_name): - register_env(test_name, lambda config: RLlibEnv(config)) - ModelCatalog.register_custom_model("SingleAgentFlatModel", SingleAgentFlatModel) - - test_dir = f"./testdir/{test_name}" - - config = ( - PPOConfig() - .rollouts(num_rollout_workers=0, rollout_fragment_length=512) - .training( - model={"custom_model": "SingleAgentFlatModel"}, - train_batch_size=512, - lr=2e-5, - gamma=0.99, - lambda_=0.9, - use_gae=True, - clip_param=0.4, - grad_clip=None, - entropy_coeff=0.1, - vf_loss_coeff=0.25, - sgd_minibatch_size=64, - num_sgd_iter=10, - ) - .environment( - env_config={ - "global_observer_type": gd.ObserverType.VECTOR, - "player_observer_type": gd.ObserverType.VECTOR, - "yaml_file": "Single-Player/GVGAI/sokoban.yaml", - }, - env=test_name, - clip_actions=True, - ) - .debugging(log_level="ERROR") - .framework(framework="torch") - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - result = tune.run( - "PPO", - name="PPO", - stop={"timesteps_total": 100}, - storage_path=test_dir, - config=config.to_dict(), - ) - - assert result is not None - - shutil.rmtree(test_dir) - - -@pytest.mark.skip(reason="ffmpeg not installed on test server") -def test_rllib_single_player_record_videos(test_name): - sep = os.pathsep - os.environ["PYTHONPATH"] = sep.join(sys.path) - - register_env(test_name, lambda config: RLlibEnv(config)) - ModelCatalog.register_custom_model("SingleAgentFlatModel", SingleAgentFlatModel) - - test_dir = f"./testdir/{test_name}" - video_dir = "videos" - - config = ( - PPOConfig() - .rollouts(num_rollout_workers=0, rollout_fragment_length=64) - .callbacks(VideoCallbacks) - .training( - model={"custom_model": "SingleAgentFlatModel"}, - train_batch_size=64, - lr=2e-5, - gamma=0.99, - lambda_=0.9, - use_gae=True, - clip_param=0.4, - grad_clip=None, - entropy_coeff=0.1, - vf_loss_coeff=0.25, - sgd_minibatch_size=64, - num_sgd_iter=10, - ) - .environment( - env_config={ - "global_observer_type": gd.ObserverType.VECTOR, - "player_observer_type": gd.ObserverType.VECTOR, - "yaml_file": "Single-Player/GVGAI/sokoban.yaml", - "max_steps": 50, - "record_video_config": {"frequency": 100, "directory": video_dir}, - }, - env=test_name, - clip_actions=True, - ) - .debugging(log_level="ERROR") - .framework(framework="torch") - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - result = tune.run( - "PPO", - name="PPO", - stop={"timesteps_total": 512}, - storage_path=test_dir, - config=config.to_dict(), - ) - - 
assert result is not None - final_video_dir = os.path.join(result.trials[0].logdir, video_dir) - assert count_videos(final_video_dir) > 0 - - shutil.rmtree(test_dir) - - -class MultiAgentFlatModel(TorchModelV2, nn.Module): - def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs): - TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs) - nn.Module.__init__(self) - self.model = nn.Sequential( - nn.Flatten(), - (nn.Linear(1458, 128)), - nn.ReLU(), - ) - self.policy_fn = nn.Linear(128, num_outputs) - self.value_fn = nn.Linear(128, 1) - - def forward(self, input_dict, state, seq_lens): - model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2)) - self._value_out = self.value_fn(model_out) - return self.policy_fn(model_out), state - - def value_function(self): - return self._value_out.flatten() - - -@pytest.mark.skip(reason="flaky on github actions") -def test_rllib_multi_agent_self_play(test_name): - sep = os.pathsep - os.environ["PYTHONPATH"] = sep.join(sys.path) - - register_env( - test_name, lambda env_config: RLlibMultiAgentWrapper(RLlibEnv(env_config)) - ) - ModelCatalog.register_custom_model("MultiAgentFlatModel", MultiAgentFlatModel) - - test_dir = f"./testdir/{test_name}" - - config = ( - PPOConfig() - .rollouts(num_rollout_workers=0, rollout_fragment_length=64) - .training( - model={"custom_model": "MultiAgentFlatModel"}, - train_batch_size=64, - lr=2e-5, - gamma=0.99, - lambda_=0.9, - use_gae=True, - clip_param=0.4, - grad_clip=None, - entropy_coeff=0.1, - vf_loss_coeff=0.25, - sgd_minibatch_size=8, - num_sgd_iter=10, - ) - .environment( - env_config={ - "global_observer_type": gd.ObserverType.VECTOR, - "player_observer_type": gd.ObserverType.VECTOR, - "yaml_file": "Multi-Agent/robot_tag_12.yaml", - }, - env=test_name, - clip_actions=True, - ) - .debugging(log_level="ERROR") - .framework(framework="torch") - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - result = tune.run( - "PPO", - name="PPO", - stop={"timesteps_total": 512}, - storage_path=test_dir, - config=config.to_dict(), - ) - - assert result is not None - - shutil.rmtree(test_dir) - - -@pytest.mark.skip(reason="ffmpeg not installed on test server") -def test_rllib_multi_agent_self_play_record_videos(test_name): - sep = os.pathsep - os.environ["PYTHONPATH"] = sep.join(sys.path) - - register_env( - test_name, lambda env_config: RLlibMultiAgentWrapper(RLlibEnv(env_config)) - ) - ModelCatalog.register_custom_model("MultiAgentFlatModel", MultiAgentFlatModel) - - test_dir = f"./testdir/{test_name}" - video_dir = "videos" - - config = ( - PPOConfig() - .rollouts(num_rollout_workers=0, rollout_fragment_length=64) - .callbacks(VideoCallbacks) - .training( - model={"custom_model": "MultiAgentFlatModel"}, - train_batch_size=64, - lr=2e-5, - gamma=0.99, - lambda_=0.9, - use_gae=True, - clip_param=0.4, - grad_clip=None, - entropy_coeff=0.1, - vf_loss_coeff=0.25, - sgd_minibatch_size=8, - num_sgd_iter=10, - ) - .environment( - env_config={ - "global_observer_type": gd.ObserverType.SPRITE_2D, - "player_observer_type": gd.ObserverType.VECTOR, - "yaml_file": "Multi-Agent/robot_tag_12.yaml", - "record_video_config": {"frequency": 2, "directory": video_dir}, - }, - env=test_name, - clip_actions=True, - ) - .debugging(log_level="ERROR") - .framework(framework="torch") - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - result = tune.run( - "PPO", - name="PPO", - stop={"timesteps_total": 512}, - storage_path=test_dir, - config=config.to_dict(), - ) - - 
assert result is not None - final_video_dir = os.path.join(result.trials[0].logdir, video_dir) - assert count_videos(final_video_dir) > 0 - - shutil.rmtree(test_dir) - - -@pytest.mark.skip(reason="ffmpeg not installed on test server") -def test_rllib_multi_agent_self_play_record_videos_all_agents(test_name): - sep = os.pathsep - os.environ["PYTHONPATH"] = sep.join(sys.path) - - register_env( - test_name, lambda env_config: RLlibMultiAgentWrapper(RLlibEnv(env_config)) - ) - ModelCatalog.register_custom_model("MultiAgentFlatModel", MultiAgentFlatModel) - - test_dir = f"./testdir/{test_name}" - video_dir = "videos" - - config = ( - PPOConfig() - .rollouts(num_rollout_workers=0, rollout_fragment_length=64) - .callbacks(VideoCallbacks) - .training( - model={"custom_model": "MultiAgentFlatModel"}, - train_batch_size=64, - lr=2e-5, - gamma=0.99, - lambda_=0.9, - use_gae=True, - clip_param=0.4, - grad_clip=None, - entropy_coeff=0.1, - vf_loss_coeff=0.25, - sgd_minibatch_size=8, - num_sgd_iter=10, - ) - .environment( - env_config={ - "global_observer_type": gd.ObserverType.SPRITE_2D, - "player_observer_type": gd.ObserverType.VECTOR, - "yaml_file": "Multi-Agent/robot_tag_12.yaml", - "player_done_variable": "player_done", - "record_video_config": { - "frequency": 2, - "directory": video_dir, - "include_agents": True, - "include_global": True, - }, - "max_steps": 200, - }, - env=test_name, - clip_actions=True, - ) - .debugging(log_level="ERROR") - .framework(framework="torch") - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - result = tune.run( - "PPO", - name="PPO", - stop={"timesteps_total": 10000}, - storage_path=test_dir, - config=config.to_dict(), - ) - - assert result is not None - final_video_dir = os.path.join(result.trials[0].logdir, video_dir) - assert count_videos(final_video_dir) > 0 - - shutil.rmtree(test_dir)
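The test updates above (entity_observer_test.py, named_observer_test.py, partial_observability_test.py) show the step() calling convention that goes with the rewritten GymWrapper.step(): a flat list where a single action has several components, and one entry per player in multi-agent environments. A rough usage sketch follows; the yaml path and the player/component counts are placeholders, not taken from this diff.

from griddly import gd
from griddly.gym import GymWrapper

# Hypothetical environment; the yaml path is a placeholder.
env = GymWrapper(
    yaml_file="path/to/game.yaml",
    player_observer_type=gd.ObserverType.VECTOR,
    global_observer_type=gd.ObserverType.VECTOR,
)
env.reset()

# A multi-part action for a single player is passed as a flat list.
obs, reward, done, truncated, info = env.step([0, 0])

# With two players, pass one entry per player instead, e.g.:
# obs, reward, done, truncated, info = env.step([[0, 0], [0, 0]])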