diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml index e840b79..d800b30 100644 --- a/.github/workflows/CD.yml +++ b/.github/workflows/CD.yml @@ -1,6 +1,10 @@ name: CD -on: workflow_dispatch +on: + workflow_dispatch: + push: + branches: + - "main" jobs: release: diff --git a/navix/_version.py b/navix/_version.py index 118ded0..011f80c 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.3.6" +__version__ = "0.3.7" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/grid.py b/navix/grid.py index 08b0d45..09c912c 100644 --- a/navix/grid.py +++ b/navix/grid.py @@ -155,49 +155,33 @@ def two_rooms(height: int, width: int, key: KeyArray) -> Tuple[Array, Array]: def crop(grid: Array, origin: Array, direction: Array, radius: int) -> Array: input_shape = grid.shape - max_dim = max(input_shape) - - # pad to square and ensure non out of bounds - padding = [] - for d in input_shape[:2]: - pad = max_dim - d - pad, rem = divmod(pad, 2) - pad = (pad + rem + radius, pad + radius) - padding.append(pad) + + # pad with radius + padding = [(radius, radius), (radius, radius)] for _ in range(len(input_shape) - 2): padding.append((0, 0)) padded = jnp.pad(grid, padding, constant_values=0) - origin = origin + jnp.asarray((padding[0][0] - radius, padding[1][0] - radius)) - height, width = (padded.shape[0] - radius * 2, padded.shape[1] - radius * 2) - # rotate - rotated, centre = jax.lax.switch( + # translate the grid such that the agent is `radius` away from the top and left edges + translated = jnp.roll(padded, -jnp.asarray(origin), axis=(0, 1)) + + # crop such that the agent is in the centre of the grid + cropped = translated[: 2 * radius + 1, : 2 * radius + 1] + + # rotate such that the agent is facing north + rotated = jax.lax.switch( direction, ( - lambda x: ( - jnp.rot90(x, 1), - (width - 1 - origin[1], origin[0]), - ), # 0 = transpose, 1 = flip - lambda x: ( - jnp.rot90(x, 2), - (height - 1 - origin[0], width - 1 - origin[1]), - ), # 0 = flip, 1 = flip - lambda x: ( - jnp.rot90(x, 3), - (origin[1], height - 1 - origin[0]), - ), # 0 = flip, 1 = transpose - lambda x: (x, (origin[0], origin[1])), + lambda x: jnp.rot90(x, 1), # 0 = transpose, 1 = flip + lambda x: jnp.rot90(x, 2), # 0 = flip, 1 = flip + lambda x: jnp.rot90(x, 3), # 0 = flip, 1 = transpose + lambda x: x, ), - padded, + cropped, ) - # translate - translated = jnp.roll(rotated, -jnp.asarray(centre), axis=(0, 1)) - - # crop - cropped = translated[: radius + 1, : 2 * radius + 1] - + cropped = rotated[:radius + 1] return jnp.asarray(cropped, dtype=grid.dtype)