Skip to content

Commit

Permalink
Merge pull request #31 from epignatelli/dev/fp
Browse files Browse the repository at this point in the history
Refactoring the rendering engine
  • Loading branch information
epignatelli authored Jun 25, 2023
2 parents d850a0f + d041edc commit eb3f676
Show file tree
Hide file tree
Showing 22 changed files with 975 additions and 371 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
**[Quickstart](#what-is-navix)** | **[Installation](#installation)** | **[Examples](#examples)** | **[Cite](#cite)**

## What is NAVIX?
NAVIX is [minigrid](https://github.com/Farama-Foundation/Minigrid) in JAX, **>10000x** faster with Autograd and XLA support.
---
NAVIX is [minigrid](https://github.com/Farama-Foundation/Minigrid) in JAX, **>1000x** faster with Autograd and XLA support.
You can see a superficial performance comparison [here](docs/profiling.ipynb).

The library is in active development, and we are working on adding more environments and features.
If you want join the development and contribute, please [open a discussion](https://github.com/epignatelli/navix/discussions/new?category=general) and let's have a chat!


## Installation
---
We currently support the OSs supported by JAX.
You can find a description [here](https://github.com/google/jax#installation).

Expand All @@ -22,13 +27,20 @@ You might want to follow the same guide to install jax for your faviourite accel
[TPU](https://github.com/google/jax#pip-installation-colab-tpu)
).

Then, install `navix` and its dependencies with:
- ### Stable
Then, install the stable version of `navix` and its dependencies with:
```bash
pip install navix
```

---
- ### Nightly
Or, if you prefer to install the latest version from source:
```bash
pip install git+https://github.com/epignatelli/navix
```

## Examples
---

### XLA compilation
One straightforward use case is to accelerate the computation of the environment with XLA compilation.
Expand Down Expand Up @@ -64,6 +76,7 @@ TODO(epignatelli): add example.


## Cite
---
If you use `helx` please consider citing it as:

```bibtex
Expand All @@ -75,4 +88,3 @@ If you use `helx` please consider citing it as:
journal = {GitHub repository},
howpublished = {\url{https://github.com/epignatelli/navix}}
}
```
91 changes: 41 additions & 50 deletions docs/profiling.ipynb → docs/performance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,51 +38,29 @@
"import jax.numpy as jnp\n",
"import navix as nx\n",
"\n",
"import gymnasium as gym\n",
"import minigrid\n",
"import random\n",
"import time\n",
"\n",
"N_TIMESTEPS = 10_000\n",
"from timeit import timeit\n",
"\n",
"\n",
"N_TIMEIT_LOOPS = 5\n",
"N_TIMESTEPS = 100\n",
"N_SEEDS = 10\n",
"\n",
"\n",
"def profile_navix(seed):\n",
" env = nx.environments.Room(16, 16, 8)\n",
" env = nx.environments.Room(16, 16, 8, observation_fn=nx.observations.rgb)\n",
" key = jax.random.PRNGKey(seed)\n",
" timestep = env.reset(key)\n",
" actions = jax.random.randint(key, (N_TIMESTEPS,), 0, 6)\n",
"\n",
" def body_fun(carry, x):\n",
" timestep = carry\n",
" action = x\n",
" timestep = env.step(timestep, jnp.asarray(action))\n",
" return timestep, ()\n",
"\n",
" return jax.lax.scan(body_fun, timestep, jnp.asarray(actions, dtype=jnp.int32))[0]\n",
" for i in range(N_TIMESTEPS):\n",
" timestep = env.step(timestep, actions[i])\n",
"\n",
"\n",
"f = jax.jit(jax.vmap(profile_navix))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "44SbP_tOp1dP"
},
"outputs": [],
"source": [
"# running 10_000 seeds in parallel\n",
"%timeit -n 5 -r 1 f(jnp.arange(10_000)).state.grid.block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iSeCh6H4qRdx"
},
"outputs": [],
"source": [
"import gymnasium as gym\n",
"import minigrid\n",
"import random\n",
" return timestep\n",
"\n",
"\n",
"def profile_minigrid(seed):\n",
Expand All @@ -95,19 +73,31 @@
" if terminated or truncated:\n",
" observation, info = env.reset()\n",
" env.close()\n",
" return observation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1b4-SrmWsYgs"
},
"outputs": [],
"source": [
"# running 1 seed\n",
"%timeit -n 5 -r 1 profile_minigrid(0)"
" return observation\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" # profile navix\n",
" print(\n",
" \"Profiling navix, N_SEEDS = {}, N_TIMESTEPS = {}\".format(N_SEEDS, N_TIMESTEPS)\n",
" )\n",
" seeds = jnp.arange(N_SEEDS)\n",
"\n",
" print(\"\\tCompiling...\")\n",
" start = time.time()\n",
" f = jax.jit(jax.vmap(profile_navix)).lower(seeds).compile()\n",
" print(\"\\tCompiled in {:.2f}s\".format(time.time() - start))\n",
"\n",
" print(\"\\tRunning ...\")\n",
" res_navix = timeit(\n",
" lambda: f(seeds).state.grid.block_until_ready(), number=N_TIMEIT_LOOPS\n",
" )\n",
" print(res_navix)\n",
"\n",
" # profile minigrid\n",
" print(\"Profiling minigrid, N_SEEDS = 1, N_TIMESTEPS = {}\".format(N_TIMESTEPS))\n",
" res_minigrid = timeit(lambda: profile_minigrid(0), number=N_TIMEIT_LOOPS)\n",
" print(res_minigrid)"
]
},
{
Expand All @@ -132,7 +122,8 @@
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.2.2"
__version__ = "0.3.2"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
99 changes: 49 additions & 50 deletions navix/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,23 @@
import jax.numpy as jnp
from jax import Array

from .components import Consumable, Pickable, State
from .components import Door, Key, State, DISCARD_PILE_COORDS
from .grid import translate, rotate


DIRECTIONS = {0: "east", 1: "south", 2: "west", 3: "north"}


def _rotate(state: State, spin: int) -> State:
direction = (state.player.direction + spin) % 4
direction = rotate(state.player.direction, spin)
player = state.player.replace(direction=direction)
return state.replace(player=player)


def _translate(position: Array, direction: Array) -> Array:
moves = (
lambda position: position + jnp.asarray((0, 1)), # east
lambda position: position + jnp.asarray((1, 0)), # south
lambda position: position + jnp.asarray((0, -1)), # west
lambda position: position + jnp.asarray((-1, 0)), # north
)
return jax.lax.switch(direction, moves, position)


def _move_allowed(state: State, position: Array) -> Array:
def _walkable(state: State, position: Array) -> Array:
# according to the grid
walkable = jnp.equal(state.grid[tuple(position)], 0)

# and not occupied by another non-walkable entity
occupied_keys = jax.vmap(lambda x: jnp.array_equal(x, position))(
state.keys.position
Expand All @@ -57,12 +49,13 @@ def _move_allowed(state: State, position: Array) -> Array:
state.doors.position
)
occupied = jnp.any(jnp.concatenate([occupied_keys, occupied_doors]))
# return: if walkable and not occupied
return jnp.logical_and(walkable, jnp.logical_not(occupied))


def _move(state: State, direction: Array) -> State:
new_position = _translate(state.player.position, direction)
can_move = _move_allowed(state, new_position)
new_position = translate(state.player.position, direction)
can_move = _walkable(state, new_position)
new_position = jnp.where(can_move, new_position, state.player.position)
player = state.player.replace(position=new_position)
return state.replace(player=player)
Expand Down Expand Up @@ -107,46 +100,52 @@ def left(state: State) -> State:
return _move(state, state.player.direction + 3)


def _one_many_position_equal(a: Array, b: Array) -> Array:
assert a.ndim == 1 and b.ndim == 2
is_equal = jnp.sum(a[None] - b, axis=-1) == 0
assert is_equal.shape == (b.shape[0],)
return is_equal


def pickup(state: State) -> State:
position_in_front = _translate(state.player.position, state.player.direction)

def _update(key: Pickable) -> Tuple[Array, Pickable]:
match = jnp.array_equal(position_in_front, key.position)
# update player's pocket
pocket = jnp.where(match, key.id, state.player.pocket)
# set to (-1, -1) the position of the key that was picked up
unset_position = jnp.asarray((-1, -1))
position = jnp.where(match, unset_position, key.position)
key = key.replace(position=position)
return pocket, key

pockets, keys = jax.vmap(_update)(state.keys)
pocket = jnp.max(pockets, axis=0)
player = state.player.replace(pocket=pocket)
position_in_front = translate(state.player.position, state.player.direction)

key_found = _one_many_position_equal(position_in_front, state.keys.position)

# update keys
positions = jnp.where(key_found, DISCARD_PILE_COORDS, state.keys.position)
keys = state.keys.replace(position=positions)

# update player's pocket, if the pocket has something else, we overwrite it
key = jnp.sum(state.keys.id * key_found, dtype=jnp.int32)
player = jax.lax.cond(jnp.any(key_found), lambda: state.player.replace(pocket=key), lambda: state.player)

return state.replace(player=player, keys=keys)


def open(state: State) -> State:
position_in_front = _translate(state.player.position, state.player.direction)

def _update(door: Consumable) -> Tuple[Array, Consumable]:
match = jnp.array_equal(position_in_front, door.position)
replacement = jnp.asarray((match - 1) * door.replacement, dtype=jnp.int32)

# update grid
grid = jnp.zeros_like(state.grid).at[tuple(door.position)].set(replacement)

# set to (-1, -1) the position of the door that was opened
unset_position = jnp.asarray((-1, -1))
position = jnp.where(match, unset_position, door.position)
door = door.replace(position=position)
return grid, door

grid, doors = jax.vmap(_update)(state.doors)
# the max makes sure that if there was a wall (-1), and it has been opened (x>0)
# we get the new value of the grid
grid = jnp.max(grid, axis=0)
return state.replace(grid=grid, doors=doors)
# get the tile in front of the player
position_in_front = translate(state.player.position, state.player.direction)

# check if there is a door in front of the player
door_found = position_in_front[None] == state.doors.position
# and that, if so, either it does not require a key or the player has the key
requires_key = state.doors.requires != -1
key_match = state.player.pocket == state.doors.requires
can_open = door_found & (key_match | ~requires_key )

# update doors
# TODO(epignatelli): in the future we want to mark the door as open, instead
# and have a different rendering for it
# if the door can be opened, move it to the discard pile
new_positions = jnp.where(can_open, DISCARD_PILE_COORDS, state.doors.position)
doors = state.doors.replace(position=new_positions)

# remove key from player's pocket
pocket = jnp.asarray(state.player.pocket * jnp.any(can_open), dtype=jnp.int32)
player = jax.lax.cond(jnp.any(can_open), lambda: state.player.replace(pocket=pocket), lambda: state.player)

return state.replace(player=player, doors=doors)


# TODO(epignatelli): a mutable dictionary here is dangerous
Expand Down
Loading

0 comments on commit eb3f676

Please sign in to comment.