Skip to content

Commit

Permalink
fixing a bunch of tests and removing rllib, rllib will be supported using gymnasium and petting zoo in future
Browse files Browse the repository at this point in the history
  • Loading branch information
Bam4d committed Oct 18, 2023
1 parent 88d1be2 commit 2600b01
Show file tree
Hide file tree
Showing 21 changed files with 73 additions and 1,077 deletions.
1 change: 1 addition & 0 deletions python/griddly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

import yaml

from griddly.gym import GymWrapperFactory
Expand Down
64 changes: 43 additions & 21 deletions python/griddly/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from gymnasium.envs.registration import register
from gymnasium.spaces import Discrete, MultiDiscrete

from griddly.loader import GriddlyLoader
from griddly import gd as gd
from griddly.loader import GriddlyLoader
from griddly.spaces.action_space import MultiAgentActionSpace
from griddly.spaces.observation_space import (
EntityObservationSpace,
Expand Down Expand Up @@ -363,7 +363,7 @@ def enable_history(self, enable: bool = True) -> None:
self._enable_history = enable
self.game.enable_history(enable)

def step( # type: ignore
def step( # type: ignore
self, action: Union[Action, List[Action]]
) -> Tuple[
Union[List[Observation], Observation],
Expand All @@ -380,28 +380,50 @@ def step( # type: ignore

ragged_actions = []
max_num_actions = 1
try:
if self.player_count == 1:
ragged_actions.append(
np.array(action, dtype=np.int32).reshape(
-1, len(self.action_space_parts)
)
)
max_num_actions = ragged_actions[0].shape[0]
else:
for p in range(self.player_count):
a: Union[Action, List[Action]]
if isinstance(action, list):
if action[p] is None:
a = np.zeros(len(self.action_space_parts), dtype=np.int32)
else:
a = action[p]
else:
a = action

ragged_actions.append(
np.array(a, dtype=np.int32).reshape(
-1, len(self.action_space_parts)
)
)

if self.player_count == 1:
ragged_actions.append(np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts)))
max_num_actions = ragged_actions[0].shape[0]
else:
for p in range(self.player_count):
if isinstance(action, list):
ragged_actions.append(np.array(action[p], dtype=np.int32).reshape(-1, len(self.action_space_parts)))
else:
ragged_actions.append(np.array(action, dtype=np.int32).reshape(-1, len(self.action_space_parts)))

if ragged_actions[p].shape[0] > max_num_actions:
max_num_actions = ragged_actions[p].shape[0]

action_data = np.zeros((self.player_count, max_num_actions, len(self.action_space_parts)), dtype=np.int32)

for p in range(self.player_count):
for i, a in enumerate(ragged_actions[p]):
action_data[p, i] = a
if ragged_actions[p].shape[0] > max_num_actions:
max_num_actions = ragged_actions[p].shape[0]

action_data = np.zeros(
(self.player_count, max_num_actions, len(self.action_space_parts)),
dtype=np.int32,
)

reward, done, truncated, info = self.game.step_parallel(action_data)
for p in range(self.player_count):
for i, a in enumerate(ragged_actions[p]):
action_data[p, i] = a

reward, done, truncated, info = self.game.step_parallel(action_data)
except Exception as e:
raise ValueError(
f"Invalid action {action} for action space {self.action_space}." \
"Example valid action: {self.action_space.sample()}",
e
)

# Compatibility with gymnasium
if self.player_count == 1:
Expand Down
3 changes: 2 additions & 1 deletion python/griddly/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from griddly import gd


class GriddlyLoader:
def __init__(self) -> None:
module_path = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -43,4 +44,4 @@ def load_string(self, yaml_string: str) -> gd.GDY:

def load_gdy(self, gdy_path: str) -> Dict[str, Any]:
with open(self.get_full_path(gdy_path)) as gdy_file:
return yaml.load(gdy_file, Loader=yaml.SafeLoader) # type: ignore
return yaml.load(gdy_file, Loader=yaml.SafeLoader) # type: ignore
4 changes: 2 additions & 2 deletions python/griddly/spaces/action_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import TYPE_CHECKING, Any, List, Optional, Union

import numpy as np
from gymnasium.spaces import Discrete, MultiDiscrete, Space
from gymnasium.spaces import Space

from griddly.typing import Action, ActionSpace
from griddly.typing import Action

if TYPE_CHECKING:
from griddly.gym import GymWrapper
Expand Down
15 changes: 8 additions & 7 deletions python/griddly/util/breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import numpy.typing as npt
import yaml

from griddly.loader import GriddlyLoader
from griddly import gd
from griddly.loader import GriddlyLoader
from griddly.util.vector_visualization import Vector2RGB


class TemporaryEnvironment:
"""
Because we have to load the game many different times with different configurations, this class makes sure we clean up objects we dont need
Because we have to load the game many different times with different configurations,
this class makes sure we clean up objects we dont need
"""

def __init__(
Expand Down Expand Up @@ -120,7 +121,7 @@ def _populate_common_properties(self) -> None:
for observer_name, config in object["Observers"].items():
self.supported_observers.add(observer_name)

self.observer_configs: Dict[str, Dict] = {
self.observer_configs: Dict[str, Dict] = {
"Block2D": {},
"Sprite2D": {},
"Vector": {},
Expand Down Expand Up @@ -207,12 +208,12 @@ def _populate_levels(self) -> None:
if self.observer_configs[observer_name]["TrackAvatar"]:
continue

for l, level in self.levels.items():
env.game.load_level(l)
for l_key, level in self.levels.items():
env.game.load_level(l_key)
env.game.reset()
rendered_level = env.render_rgb()
self.levels[l]["Observers"][observer_name] = rendered_level
self.levels[l]["Size"] = [
self.levels[l_key]["Observers"][observer_name] = rendered_level
self.levels[l_key]["Size"] = [
env.game.get_width(),
env.game.get_height(),
]
6 changes: 4 additions & 2 deletions python/griddly/util/environment_generator_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from typing import Optional, List, Any, Dict, Union
from typing import Any, Dict, List, Optional, Union

import gymnasium as gym
import numpy as np
import yaml

from griddly import gd
from griddly.gym import GymWrapperFactory, GymWrapper
from griddly.gym import GymWrapper, GymWrapperFactory


class EnvironmentGeneratorGenerator:
Expand Down
Empty file.
144 changes: 0 additions & 144 deletions python/griddly/util/rllib/callbacks.py

This file was deleted.

Empty file.
Loading

0 comments on commit 2600b01

Please sign in to comment.