Skip to content

Commit

Permalink
Merge pull request #32 from epignatelli/sprits/opendoor
Browse files Browse the repository at this point in the history
Sprits/opendoor
  • Loading branch information
epignatelli authored Jun 27, 2023
2 parents eb3f676 + 2444e5f commit 2dcf7cc
Show file tree
Hide file tree
Showing 22 changed files with 677 additions and 359 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# setuptools_scm
helx/version.py

# developing
playground.ipynb

# vscode
.vscode/
wandb/
Expand Down
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
**[Quickstart](#what-is-navix)** | **[Installation](#installation)** | **[Examples](#examples)** | **[Cite](#cite)**

## What is NAVIX?
---
NAVIX is [minigrid](https://github.com/Farama-Foundation/Minigrid) in JAX, **>1000x** faster with Autograd and XLA support.
You can see a superficial performance comparison [here](docs/profiling.ipynb).

Expand All @@ -17,7 +16,6 @@ If you want join the development and contribute, please [open a discussion](http


## Installation
---
We currently support the OSs supported by JAX.
You can find a description [here](https://github.com/google/jax#installation).

Expand All @@ -40,7 +38,6 @@ pip install git+https://github.com/epignatelli/navix
```

## Examples
---

### XLA compilation
One straightforward use case is to accelerate the computation of the environment with XLA compilation.
Expand Down Expand Up @@ -76,7 +73,6 @@ TODO(epignatelli): add example.


## Cite
---
If you use `helx` please consider citing it as:

```bibtex
Expand Down
1 change: 1 addition & 0 deletions navix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import (
actions,
components,
entities,
graphics,
grid,
observations,
Expand Down
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.3.2"
__version__ = "0.3.3"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
79 changes: 36 additions & 43 deletions navix/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,41 @@
import jax.numpy as jnp
from jax import Array

from .components import Door, Key, State, DISCARD_PILE_COORDS
from .grid import translate, rotate
from .entities import Door, Key, State
from .components import DISCARD_PILE_COORDS
from .grid import translate, rotate, positions_equal


DIRECTIONS = {0: "east", 1: "south", 2: "west", 3: "north"}


def _rotate(state: State, spin: int) -> State:
direction = rotate(state.player.direction, spin)
player = state.player.replace(direction=direction)
return state.replace(player=player)
direction = rotate(state.players.direction, spin)
player = state.players.replace(direction=direction)
return state.replace(players=player)


def _walkable(state: State, position: Array) -> Array:
# according to the grid
walkable = jnp.equal(state.grid[tuple(position)], 0)

# and not occupied by another non-walkable entity
occupied_keys = jax.vmap(lambda x: jnp.array_equal(x, position))(
state.keys.position
)
occupied_doors = jax.vmap(lambda x: jnp.array_equal(x, position))(
state.doors.position
)
occupied = jnp.any(jnp.concatenate([occupied_keys, occupied_doors]))
occupied_keys = positions_equal(position, state.keys.position)
# occupied by a door, and door is not open
occupied_doors = positions_equal(position, state.doors.position)
occupied_doors = occupied_doors & ~state.doors.open

occupied = jnp.any(jnp.logical_or(occupied_keys, occupied_doors))
# return: if walkable and not occupied
return jnp.logical_and(walkable, jnp.logical_not(occupied))


def _move(state: State, direction: Array) -> State:
new_position = translate(state.player.position, direction)
new_position = translate(state.players.position, direction)
can_move = _walkable(state, new_position)
new_position = jnp.where(can_move, new_position, state.player.position)
player = state.player.replace(position=new_position)
return state.replace(player=player)
new_position = jnp.where(can_move, new_position, state.players.position)
player = state.players.replace(position=new_position)
return state.replace(players=player)


def undefined(state: State) -> State:
Expand All @@ -85,67 +85,60 @@ def rotate_ccw(state: State) -> State:


def forward(state: State) -> State:
return _move(state, state.player.direction)
return _move(state, state.players.direction)


def right(state: State) -> State:
return _move(state, state.player.direction + 1)
return _move(state, state.players.direction + 1)


def backward(state: State) -> State:
return _move(state, state.player.direction + 2)
return _move(state, state.players.direction + 2)


def left(state: State) -> State:
return _move(state, state.player.direction + 3)


def _one_many_position_equal(a: Array, b: Array) -> Array:
assert a.ndim == 1 and b.ndim == 2
is_equal = jnp.sum(a[None] - b, axis=-1) == 0
assert is_equal.shape == (b.shape[0],)
return is_equal
return _move(state, state.players.direction + 3)


def pickup(state: State) -> State:
position_in_front = translate(state.player.position, state.player.direction)
position_in_front = translate(state.players.position, state.players.direction)

key_found = _one_many_position_equal(position_in_front, state.keys.position)
key_found = positions_equal(position_in_front, state.keys.position)

# update keys
positions = jnp.where(key_found, DISCARD_PILE_COORDS, state.keys.position)
keys = state.keys.replace(position=positions)

# update player's pocket, if the pocket has something else, we overwrite it
key = jnp.sum(state.keys.id * key_found, dtype=jnp.int32)
player = jax.lax.cond(jnp.any(key_found), lambda: state.player.replace(pocket=key), lambda: state.player)
player = jax.lax.cond(jnp.any(key_found), lambda: state.players.replace(pocket=key), lambda: state.players)

return state.replace(player=player, keys=keys)
return state.replace(players=player, keys=keys)


def open(state: State) -> State:
"""Unlocks and opens an openable object (like a door) if possible"""
# get the tile in front of the player
position_in_front = translate(state.player.position, state.player.direction)
position_in_front = translate(state.players.position, state.players.direction)

# check if there is a door in front of the player
door_found = position_in_front[None] == state.doors.position
door_found = positions_equal(position_in_front, state.doors.position)

# and that, if so, either it does not require a key or the player has the key
requires_key = state.doors.requires != -1
key_match = state.player.pocket == state.doors.requires
key_match = state.players.pocket == state.doors.requires
can_open = door_found & (key_match | ~requires_key )

# update doors
# TODO(epignatelli): in the future we want to mark the door as open, instead
# and have a different rendering for it
# if the door can be opened, move it to the discard pile
new_positions = jnp.where(can_open, DISCARD_PILE_COORDS, state.doors.position)
doors = state.doors.replace(position=new_positions)
# update doors if closed and can_open
do_open = (~state.doors.open & can_open)
open = jnp.where(do_open, True, state.doors.open)
doors = state.doors.replace(open=open)

# remove key from player's pocket
pocket = jnp.asarray(state.player.pocket * jnp.any(can_open), dtype=jnp.int32)
player = jax.lax.cond(jnp.any(can_open), lambda: state.player.replace(pocket=pocket), lambda: state.player)
pocket = jnp.asarray(state.players.pocket * jnp.any(can_open), dtype=jnp.int32)
player = jax.lax.cond(jnp.any(can_open), lambda: state.players.replace(pocket=pocket), lambda: state.players)

return state.replace(player=player, doors=doors)
return state.replace(players=player, doors=doors)


# TODO(epignatelli): a mutable dictionary here is dangerous
Expand Down
145 changes: 37 additions & 108 deletions navix/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,137 +19,66 @@


from __future__ import annotations
from typing import Dict
from enum import IntEnum

from jax import Array
from flax import struct
from jax.random import KeyArray
import jax.numpy as jnp

from .graphics import RenderingCache


DISCARD_PILE_COORDS = jnp.asarray((0, -1), dtype=jnp.int32)
DISCARD_PILE_IDX = jnp.asarray(-1, dtype=jnp.int32)
EMPTY_POCKET_ID = jnp.asarray(-1, dtype=jnp.int32)
UNSET_DIRECTION = jnp.asarray(-1, dtype=jnp.int32)
UNSET_CONSUMED = jnp.asarray(-1, dtype=jnp.int32)


class Component(struct.PyTreeNode):
"""A component is a part of the state of the environment."""
class EntityType(IntEnum):
WALL = 0
FLOOR = 1
PLAYER = 2
GOAL = 3
KEY = 4
DOOR = 5


class Player(Component):
"""Players are entities that can act around the environment"""
class Component(struct.PyTreeNode):
entity_type: Array = jnp.asarray(0, dtype=jnp.int32)
"""The type of the entity, 0 = player, 1 = goal, 2 = key, 3 = door"""

position: Array = DISCARD_PILE_COORDS # IntArray['b 2']

class Positionable(struct.PyTreeNode):
position: Array = DISCARD_PILE_COORDS
"""The (row, column) position of the entity in the grid, defaults to the discard pile (-1, -1)"""
# TODO(epignatelli): consider batching player over the number of players
# to allow tranposing the entities pytree for faster computation
# and to prepare the ground for multi-agent environments
tag: Array = jnp.asarray(1) # IntArray['2']
"""The tag of the component, used to identify the type of the component in `oobservations.categorical`"""
# we mark direction as static because it is convenient for mapping observations (e.g. jnp.rot90(grid, k=direction)
# however, this is feasible because we only have 4 direcitions
# will it scale in multi-agent settings?
direction: Array = jnp.asarray(0, dtype=jnp.int32) # IntArray['2']
"""The direction the entity: 0 = east, 1 = south, 2 = west, 3 = north"""
pocket: Array = EMPTY_POCKET_ID # IntArray['2']
"""The id of the item in the pocket (0 if empty)"""


class Goal(Component):
"""Goals are entities that can be reached by the player"""
class Directional(struct.PyTreeNode):
direction: Array = jnp.asarray(0, dtype=jnp.int32)
"""The direction the entity: 0 = east, 1 = south, 2 = west, 3 = north"""

position: Array = DISCARD_PILE_COORDS[None] # IntArray['b 2']
"""The (row, column) position of the entity in the grid, defaults to the discard pile (-1, -1)"""
tag: Array = jnp.ones((1,), dtype=jnp.int32) + 1 # IntArray['b']

class HasTag(struct.PyTreeNode):
tag: Array = jnp.asarray(0, dtype=jnp.int32)
"""The tag of the component, used to identify the type of the component in `oobservations.categorical`"""
probability: Array = jnp.ones((1,), dtype=jnp.float32) # FloatArray['b']
"""The probability of receiving the reward, if reached."""


class Key(Component):
"""Pickable items are world objects that can be picked up by the player.
Examples of pickable items are keys, coins, etc."""
class Stochastic(struct.PyTreeNode):
probability: Array = jnp.asarray(1.0, dtype=jnp.float32)
"""The probability of receiving the reward, if reached."""

position: Array = DISCARD_PILE_COORDS[None] # IntArray['b 2']
"""The (row, column) position of the entity in the grid, defaults to the discard pile (-1, -1)"""
id: Array = jnp.ones((1,), dtype=jnp.int32) # IntArray['b']
"""The id of the item. If set, it must be >= 1."""

@property
def tag(self):
return -self.id
class Openable(struct.PyTreeNode):
requires: Array = EMPTY_POCKET_ID
"""The id of the item required to consume this item. If set, it must be >= 1."""
open: Array = jnp.asarray(False, dtype=jnp.bool_)
"""Whether the item is open or not."""


class Pickable(struct.PyTreeNode):
id: Array = jnp.asarray(1, dtype=jnp.int32)
"""The id of the item. If set, it must be >= 1."""

class Door(Component):
"""Consumable items are world objects that can be consumed by the player.
Consuming an item requires a tool (e.g. a key to open a door).
A tool is an id (int) of another item, specified in the `requires` field (-1 if no tool is required).
After an item is consumed, it is both removed from the `state.entities` collection, and replaced in the grid
by the item specified in the `replacement` field (0 = floor by default).
Examples of consumables are doors (to open) food (to eat) and water (to drink), etc.
"""

position: Array = DISCARD_PILE_COORDS[None] # IntArray['b 2']
"""The (row, column) position of the entity in the grid, defaults to the discard pile (-1, -1)"""
requires: Array = EMPTY_POCKET_ID[None] # IntArray['b']
"""The id of the item required to consume this item. If set, it must be >= 1."""
replacement: Array = jnp.zeros((1,), dtype=jnp.float32) # IntArray['b']
"""The grid signature to replace the item with, usually 0 (floor). If set, it must be >= 1."""

@property
def tag(self) -> Array: # -> IntArray['b']
return self.requires


class State(struct.PyTreeNode):
"""The Markovian state of the environment"""

key: KeyArray
"""The random number generator state"""
grid: Array
"""The base map of the environment that remains constant throughout the training"""
cache: RenderingCache
"""The rendering cache to speed up rendering"""
player: Player # we can potentially extend this to multiple players easily
"""The player entity"""
goals: Goal = Goal()
"""The goal entity, batched over the number of goals"""
keys: Key = Key()
"""The key entity, batched over the number of keys"""
doors: Door = Door()
"""The door entity, batched over the number of doors"""

def get_positions(self, axis: int = -1) -> Array:
return jnp.stack(
[
*self.keys.position,
*self.doors.position,
*self.goals.position,
self.player.position,
],
axis=axis,
)

def get_tags(self, axis: int = -1) -> Array:
return jnp.stack(
[
*self.keys.tag,
*self.doors.tag,
*self.goals.tag,
self.player.tag,
],
axis=axis,
)

def get_tiles(self, tiles_registry: Dict[str, Array], axis: int = 0) -> Array:
return jnp.stack(
[
*([tiles_registry["key"]] * len(self.keys.position)),
*([tiles_registry["door"]] * len(self.doors.position)),
*([tiles_registry["goal"]] * len(self.goals.position)),
tiles_registry["player"],
],
axis=axis,
)
class Holder(struct.PyTreeNode):
pocket: Array = EMPTY_POCKET_ID
"""The id of the item in the pocket (0 if empty)"""
Loading

0 comments on commit 2dcf7cc

Please sign in to comment.