Skip to content

Commit

Permalink
Merge pull request #43 from epignatelli/spaces
Browse files Browse the repository at this point in the history
Add Spaces
  • Loading branch information
epignatelli authored Jul 19, 2023
2 parents 5832d9d + 25c7603 commit 9b832b6
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 18 deletions.
1 change: 1 addition & 0 deletions navix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@
environments,
terminations,
config,
spaces,
)
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.3.7"
__version__ = "0.3.8"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
4 changes: 3 additions & 1 deletion navix/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def transparent(self) -> Array:

@property
def sprite(self) -> Array:
sprite = SPRITES_REGISTRY[Entities.DOOR.value][self.direction, jnp.asarray(self.open, dtype=jnp.int32)]
sprite = SPRITES_REGISTRY[Entities.DOOR.value][
self.direction, jnp.asarray(self.open, dtype=jnp.int32)
]
if sprite.ndim == 3:
# batch it
sprite = sprite[None]
Expand Down
37 changes: 36 additions & 1 deletion navix/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from __future__ import annotations

import sys
import abc
from enum import IntEnum
from typing import Any, Callable, Dict
Expand All @@ -31,9 +32,10 @@


from .. import tasks, terminations, observations
from ..graphics import RenderingCache
from ..graphics import RenderingCache, TILE_SIZE
from ..entities import State
from ..actions import ACTIONS
from ..spaces import Space, Discrete, Continuous


class StepType(IntEnum):
Expand Down Expand Up @@ -77,6 +79,39 @@ class Environment(struct.PyTreeNode):
pytree_node=False, default=terminations.on_navigation_completion
)

@property
def observation_space(self) -> Space:
if self.observation_fn == observations.none:
return Continuous(shape=())
elif self.observation_fn == observations.categorical:
return Discrete(sys.maxsize, shape=(self.height, self.width))
elif self.observation_fn == observations.categorical_first_person:
radius = observations.RADIUS
return Discrete(sys.maxsize, shape=(radius + 1, radius * 2 + 1))
elif self.observation_fn == observations.rgb:
return Discrete(
256,
shape=(self.height * TILE_SIZE, self.width * TILE_SIZE, 3),
dtype=jnp.uint8,
)
elif self.observation_fn == observations.rgb_first_person:
radius = observations.RADIUS
return Discrete(
256,
shape=(radius * TILE_SIZE * 2 + 1, radius * TILE_SIZE * 2 + 1, 3),
dtype=jnp.uint8,
)
else:
raise NotImplementedError(
"Unknown observation space for observation function {}".format(
self.observation_fn
)
)

@property
def action_space(self) -> Space:
return Discrete(len(ACTIONS))

@abc.abstractmethod
def reset(self, key: KeyArray, cache: RenderingCache | None = None) -> Timestep:
raise NotImplementedError()
Expand Down
20 changes: 15 additions & 5 deletions navix/environments/keydoor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
class KeyDoor(Environment):
def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Timestep: # type: ignore
# check minimum height and width
assert self.height > 3, f"Room height must be greater than 3, got {self.height} instead"
assert self.width > 4, f"Room width must be greater than 5, got {self.width} instead"
assert (
self.height > 3
), f"Room height must be greater than 3, got {self.height} instead"
assert (
self.width > 4
), f"Room width must be greater than 5, got {self.width} instead"

key, k1, k2, k3, k4 = jax.random.split(key, 5)

Expand All @@ -39,13 +43,19 @@ def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Tim
wall_cols = jnp.asarray([door_col] * (self.height - 2))
wall_pos = jnp.stack((wall_rows, wall_cols), axis=1)
# remove wall where the door is
wall_pos = jnp.delete(wall_pos, door_row - 1, axis=0, assume_unique_indices=True)
wall_pos = jnp.delete(
wall_pos, door_row - 1, axis=0, assume_unique_indices=True
)
walls = Wall(position=wall_pos)

# get rooms
first_room_mask = mask_by_coordinates(grid, (jnp.asarray(self.height), door_col), jnp.less)
first_room_mask = mask_by_coordinates(
grid, (jnp.asarray(self.height), door_col), jnp.less
)
first_room = jnp.where(first_room_mask, grid, -1) # put walls where not mask
second_room_mask = mask_by_coordinates(grid, (jnp.asarray(0), door_col), jnp.greater)
second_room_mask = mask_by_coordinates(
grid, (jnp.asarray(0), door_col), jnp.greater
)
second_room = jnp.where(second_room_mask, grid, -1) # put walls where not mask

# spawn player
Expand Down
3 changes: 1 addition & 2 deletions navix/environments/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@


from __future__ import annotations
from typing import Callable, Union
from typing import Union

import jax
import jax.numpy as jnp
from jax.random import KeyArray


from ..components import EMPTY_POCKET_ID
from ..entities import Entities, Goal, Player, State
from ..grid import random_positions, random_directions, room
Expand Down
2 changes: 1 addition & 1 deletion navix/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def crop(grid: Array, origin: Array, direction: Array, radius: int) -> Array:
cropped,
)

cropped = rotated[:radius + 1]
cropped = rotated[: radius + 1]
return jnp.asarray(cropped, dtype=grid.dtype)


Expand Down
13 changes: 7 additions & 6 deletions navix/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from .grid import align, idx_from_coordinates, crop, view_cone


RADIUS = 3


def none(
state: State,
) -> Array:
Expand All @@ -52,7 +55,6 @@ def categorical(

def categorical_first_person(
state: State,
radius: int = 3,
) -> Array:
# get transparency map
transparency_map = jnp.where(state.grid == 0, 1, 0)
Expand All @@ -62,14 +64,14 @@ def categorical_first_person(

# apply view mask
player = state.get_player()
view = view_cone(transparency_map, player.position, radius)
view = view_cone(transparency_map, player.position, RADIUS)

# get categorical representation
tags = state.get_tags()
obs = state.grid.at[tuple(positions.T)].set(tags) * view

# crop grid to agent's view
obs = crop(obs, player.position, player.direction, radius)
obs = crop(obs, player.position, player.direction, RADIUS)

return obs

Expand Down Expand Up @@ -99,7 +101,6 @@ def rgb(

def rgb_first_person(
state: State,
radius: int = 3,
) -> Array:
# calculate final image size
image_size = (
Expand All @@ -113,7 +114,7 @@ def rgb_first_person(
transparent = state.get_transparency()
transparency_map = transparency_map.at[tuple(positions.T)].set(~transparent)
player = state.get_player()
view = view_cone(transparency_map, player.position, radius)
view = view_cone(transparency_map, player.position, RADIUS)
view = jax.image.resize(view, image_size, method="nearest")
view = jnp.tile(view[..., None], (1, 1, 3))

Expand All @@ -131,7 +132,7 @@ def rgb_first_person(
patchwork = patches.reshape(*state.grid.shape, *patches.shape[1:])

# crop grid to agent's view
patchwork = crop(patchwork, player.position, player.direction, radius)
patchwork = crop(patchwork, player.position, player.direction, RADIUS)

# reconstruct image
obs = jnp.swapaxes(patchwork, 1, 2)
Expand Down
70 changes: 70 additions & 0 deletions navix/spaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright [2023] The Helx Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations
from typing import Callable

import jax
import jax.numpy as jnp
from jax.random import KeyArray

from jax.core import Shape
from jax import Array
from jax.core import ShapedArray, Shape


POS_INF = jnp.asarray(1e16)
NEG_INF = jnp.asarray(-1e16)


class Space(ShapedArray):
minimum: Array = NEG_INF
maximum: Array = POS_INF

def __repr__(self):
return "{}, min={}, max={})".format(
super().__repr__()[:-1], self.minimum, self.maximum
)

def sample(self, key: KeyArray) -> Array:
raise NotImplementedError()


class Discrete(Space):
def __init__(self, n_elements: int, shape: Shape = (), dtype=jnp.int32):
super().__init__(shape, dtype)
self.minimum = jnp.asarray(0)
self.maximum = jnp.asarray(n_elements - 1)

def sample(self, key: KeyArray) -> Array:
return jax.random.randint(
key, self.shape, self.minimum, self.maximum, self.dtype
)


class Continuous(Space):
def __init__(self, shape: Shape = (), minimum=NEG_INF, maximum=POS_INF):
super().__init__(shape, jnp.float32)
self.minimum = minimum
self.maximum = maximum

def sample(self, key: KeyArray) -> Array:
assert jnp.issubdtype(self.dtype, jnp.floating)
# see: https://github.com/google/jax/issues/14003
lower = jnp.nan_to_num(self.minimum)
upper = upper = jnp.nan_to_num(self.maximum)
return jax.random.uniform(
key, self.shape, minval=lower, maxval=upper, dtype=self.dtype
)
2 changes: 1 addition & 1 deletion tests/test_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ def test_keydoor2():
# test_room()
# jax.jit(test_room)()
# test_keydoor()
test_keydoor2()
test_keydoor2()

0 comments on commit 9b832b6

Please sign in to comment.