Skip to content

Commit

Permalink
Merge pull request #42 from epignatelli/obs/crop
Browse files Browse the repository at this point in the history
perf(obs): improve performance of `grid.crop`
  • Loading branch information
epignatelli authored Jul 13, 2023
2 parents be4668e + 937e626 commit 5832d9d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 35 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/CD.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name: CD

on: workflow_dispatch
on:
workflow_dispatch:
push:
branches:
- "main"

jobs:
release:
Expand Down
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.3.6"
__version__ = "0.3.7"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
50 changes: 17 additions & 33 deletions navix/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,49 +155,33 @@ def two_rooms(height: int, width: int, key: KeyArray) -> Tuple[Array, Array]:

def crop(grid: Array, origin: Array, direction: Array, radius: int) -> Array:
input_shape = grid.shape
max_dim = max(input_shape)

# pad to square and ensure non out of bounds
padding = []
for d in input_shape[:2]:
pad = max_dim - d
pad, rem = divmod(pad, 2)
pad = (pad + rem + radius, pad + radius)
padding.append(pad)

# pad with radius
padding = [(radius, radius), (radius, radius)]
for _ in range(len(input_shape) - 2):
padding.append((0, 0))

padded = jnp.pad(grid, padding, constant_values=0)
origin = origin + jnp.asarray((padding[0][0] - radius, padding[1][0] - radius))
height, width = (padded.shape[0] - radius * 2, padded.shape[1] - radius * 2)

# rotate
rotated, centre = jax.lax.switch(
# translate the grid such that the agent is `radius` away from the top and left edges
translated = jnp.roll(padded, -jnp.asarray(origin), axis=(0, 1))

# crop such that the agent is in the centre of the grid
cropped = translated[: 2 * radius + 1, : 2 * radius + 1]

# rotate such that the agent is facing north
rotated = jax.lax.switch(
direction,
(
lambda x: (
jnp.rot90(x, 1),
(width - 1 - origin[1], origin[0]),
), # 0 = transpose, 1 = flip
lambda x: (
jnp.rot90(x, 2),
(height - 1 - origin[0], width - 1 - origin[1]),
), # 0 = flip, 1 = flip
lambda x: (
jnp.rot90(x, 3),
(origin[1], height - 1 - origin[0]),
), # 0 = flip, 1 = transpose
lambda x: (x, (origin[0], origin[1])),
lambda x: jnp.rot90(x, 1), # 0 = transpose, 1 = flip
lambda x: jnp.rot90(x, 2), # 0 = flip, 1 = flip
lambda x: jnp.rot90(x, 3), # 0 = flip, 1 = transpose
lambda x: x,
),
padded,
cropped,
)

# translate
translated = jnp.roll(rotated, -jnp.asarray(centre), axis=(0, 1))

# crop
cropped = translated[: radius + 1, : 2 * radius + 1]

cropped = rotated[:radius + 1]
return jnp.asarray(cropped, dtype=grid.dtype)


Expand Down

0 comments on commit 5832d9d

Please sign in to comment.