diff --git a/navix/__init__.py b/navix/__init__.py index 0aa49da..d809db3 100644 --- a/navix/__init__.py +++ b/navix/__init__.py @@ -29,4 +29,5 @@ environments, terminations, config, + spaces, ) diff --git a/navix/_version.py b/navix/_version.py index 011f80c..a4beb31 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.3.7" +__version__ = "0.3.8" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/entities.py b/navix/entities.py index 0b7b332..c667f49 100644 --- a/navix/entities.py +++ b/navix/entities.py @@ -241,7 +241,9 @@ def transparent(self) -> Array: @property def sprite(self) -> Array: - sprite = SPRITES_REGISTRY[Entities.DOOR.value][self.direction, jnp.asarray(self.open, dtype=jnp.int32)] + sprite = SPRITES_REGISTRY[Entities.DOOR.value][ + self.direction, jnp.asarray(self.open, dtype=jnp.int32) + ] if sprite.ndim == 3: # batch it sprite = sprite[None] diff --git a/navix/environments/environment.py b/navix/environments/environment.py index 0b2781b..225965d 100644 --- a/navix/environments/environment.py +++ b/navix/environments/environment.py @@ -20,6 +20,7 @@ from __future__ import annotations +import sys import abc from enum import IntEnum from typing import Any, Callable, Dict @@ -31,9 +32,10 @@ from .. import tasks, terminations, observations -from ..graphics import RenderingCache +from ..graphics import RenderingCache, TILE_SIZE from ..entities import State from ..actions import ACTIONS +from ..spaces import Space, Discrete, Continuous class StepType(IntEnum): @@ -77,6 +79,39 @@ class Environment(struct.PyTreeNode): pytree_node=False, default=terminations.on_navigation_completion ) + @property + def observation_space(self) -> Space: + if self.observation_fn == observations.none: + return Continuous(shape=()) + elif self.observation_fn == observations.categorical: + return Discrete(sys.maxsize, shape=(self.height, self.width)) + elif self.observation_fn == observations.categorical_first_person: + radius = observations.RADIUS + return Discrete(sys.maxsize, shape=(radius + 1, radius * 2 + 1)) + elif self.observation_fn == observations.rgb: + return Discrete( + 256, + shape=(self.height * TILE_SIZE, self.width * TILE_SIZE, 3), + dtype=jnp.uint8, + ) + elif self.observation_fn == observations.rgb_first_person: + radius = observations.RADIUS + return Discrete( + 256, + shape=(radius * TILE_SIZE * 2 + 1, radius * TILE_SIZE * 2 + 1, 3), + dtype=jnp.uint8, + ) + else: + raise NotImplementedError( + "Unknown observation space for observation function {}".format( + self.observation_fn + ) + ) + + @property + def action_space(self) -> Space: + return Discrete(len(ACTIONS)) + @abc.abstractmethod def reset(self, key: KeyArray, cache: RenderingCache | None = None) -> Timestep: raise NotImplementedError() diff --git a/navix/environments/keydoor.py b/navix/environments/keydoor.py index 6f8c8b0..cb3e628 100644 --- a/navix/environments/keydoor.py +++ b/navix/environments/keydoor.py @@ -14,8 +14,12 @@ class KeyDoor(Environment): def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Timestep: # type: ignore # check minimum height and width - assert self.height > 3, f"Room height must be greater than 3, got {self.height} instead" - assert self.width > 4, f"Room width must be greater than 5, got {self.width} instead" + assert ( + self.height > 3 + ), f"Room height must be greater than 3, got {self.height} instead" + assert ( + self.width > 4 + ), f"Room width must be greater than 5, got {self.width} instead" key, k1, k2, k3, k4 = jax.random.split(key, 5) @@ -39,13 +43,19 @@ def reset(self, key: KeyArray, cache: Union[RenderingCache, None] = None) -> Tim wall_cols = jnp.asarray([door_col] * (self.height - 2)) wall_pos = jnp.stack((wall_rows, wall_cols), axis=1) # remove wall where the door is - wall_pos = jnp.delete(wall_pos, door_row - 1, axis=0, assume_unique_indices=True) + wall_pos = jnp.delete( + wall_pos, door_row - 1, axis=0, assume_unique_indices=True + ) walls = Wall(position=wall_pos) # get rooms - first_room_mask = mask_by_coordinates(grid, (jnp.asarray(self.height), door_col), jnp.less) + first_room_mask = mask_by_coordinates( + grid, (jnp.asarray(self.height), door_col), jnp.less + ) first_room = jnp.where(first_room_mask, grid, -1) # put walls where not mask - second_room_mask = mask_by_coordinates(grid, (jnp.asarray(0), door_col), jnp.greater) + second_room_mask = mask_by_coordinates( + grid, (jnp.asarray(0), door_col), jnp.greater + ) second_room = jnp.where(second_room_mask, grid, -1) # put walls where not mask # spawn player diff --git a/navix/environments/room.py b/navix/environments/room.py index 4e6804f..de6918f 100644 --- a/navix/environments/room.py +++ b/navix/environments/room.py @@ -19,13 +19,12 @@ from __future__ import annotations -from typing import Callable, Union +from typing import Union import jax import jax.numpy as jnp from jax.random import KeyArray - from ..components import EMPTY_POCKET_ID from ..entities import Entities, Goal, Player, State from ..grid import random_positions, random_directions, room diff --git a/navix/grid.py b/navix/grid.py index 09c912c..eb8440e 100644 --- a/navix/grid.py +++ b/navix/grid.py @@ -181,7 +181,7 @@ def crop(grid: Array, origin: Array, direction: Array, radius: int) -> Array: cropped, ) - cropped = rotated[:radius + 1] + cropped = rotated[: radius + 1] return jnp.asarray(cropped, dtype=grid.dtype) diff --git a/navix/observations.py b/navix/observations.py index 99221e5..2ff6d0e 100644 --- a/navix/observations.py +++ b/navix/observations.py @@ -30,6 +30,9 @@ from .grid import align, idx_from_coordinates, crop, view_cone +RADIUS = 3 + + def none( state: State, ) -> Array: @@ -52,7 +55,6 @@ def categorical( def categorical_first_person( state: State, - radius: int = 3, ) -> Array: # get transparency map transparency_map = jnp.where(state.grid == 0, 1, 0) @@ -62,14 +64,14 @@ def categorical_first_person( # apply view mask player = state.get_player() - view = view_cone(transparency_map, player.position, radius) + view = view_cone(transparency_map, player.position, RADIUS) # get categorical representation tags = state.get_tags() obs = state.grid.at[tuple(positions.T)].set(tags) * view # crop grid to agent's view - obs = crop(obs, player.position, player.direction, radius) + obs = crop(obs, player.position, player.direction, RADIUS) return obs @@ -99,7 +101,6 @@ def rgb( def rgb_first_person( state: State, - radius: int = 3, ) -> Array: # calculate final image size image_size = ( @@ -113,7 +114,7 @@ def rgb_first_person( transparent = state.get_transparency() transparency_map = transparency_map.at[tuple(positions.T)].set(~transparent) player = state.get_player() - view = view_cone(transparency_map, player.position, radius) + view = view_cone(transparency_map, player.position, RADIUS) view = jax.image.resize(view, image_size, method="nearest") view = jnp.tile(view[..., None], (1, 1, 3)) @@ -131,7 +132,7 @@ def rgb_first_person( patchwork = patches.reshape(*state.grid.shape, *patches.shape[1:]) # crop grid to agent's view - patchwork = crop(patchwork, player.position, player.direction, radius) + patchwork = crop(patchwork, player.position, player.direction, RADIUS) # reconstruct image obs = jnp.swapaxes(patchwork, 1, 2) diff --git a/navix/spaces.py b/navix/spaces.py new file mode 100644 index 0000000..66e0011 --- /dev/null +++ b/navix/spaces.py @@ -0,0 +1,70 @@ +# Copyright [2023] The Helx Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations +from typing import Callable + +import jax +import jax.numpy as jnp +from jax.random import KeyArray + +from jax.core import Shape +from jax import Array +from jax.core import ShapedArray, Shape + + +POS_INF = jnp.asarray(1e16) +NEG_INF = jnp.asarray(-1e16) + + +class Space(ShapedArray): + minimum: Array = NEG_INF + maximum: Array = POS_INF + + def __repr__(self): + return "{}, min={}, max={})".format( + super().__repr__()[:-1], self.minimum, self.maximum + ) + + def sample(self, key: KeyArray) -> Array: + raise NotImplementedError() + + +class Discrete(Space): + def __init__(self, n_elements: int, shape: Shape = (), dtype=jnp.int32): + super().__init__(shape, dtype) + self.minimum = jnp.asarray(0) + self.maximum = jnp.asarray(n_elements - 1) + + def sample(self, key: KeyArray) -> Array: + return jax.random.randint( + key, self.shape, self.minimum, self.maximum, self.dtype + ) + + +class Continuous(Space): + def __init__(self, shape: Shape = (), minimum=NEG_INF, maximum=POS_INF): + super().__init__(shape, jnp.float32) + self.minimum = minimum + self.maximum = maximum + + def sample(self, key: KeyArray) -> Array: + assert jnp.issubdtype(self.dtype, jnp.floating) + # see: https://github.com/google/jax/issues/14003 + lower = jnp.nan_to_num(self.minimum) + upper = upper = jnp.nan_to_num(self.maximum) + return jax.random.uniform( + key, self.shape, minval=lower, maxval=upper, dtype=self.dtype + ) diff --git a/tests/test_environments.py b/tests/test_environments.py index 3d2c06c..48d6bf6 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -83,4 +83,4 @@ def test_keydoor2(): # test_room() # jax.jit(test_room)() # test_keydoor() - test_keydoor2() \ No newline at end of file + test_keydoor2()