Skip to content

Commit

Permalink
Merge pull request #87 from epignatelli/fix-env-create
Browse files Browse the repository at this point in the history
Fix env create
  • Loading branch information
epignatelli authored Jul 8, 2024
2 parents d926b98 + 583d890 commit 73f924b
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.6.16"
__version__ = "0.6.17"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
12 changes: 6 additions & 6 deletions navix/environments/empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
height=8,
width=8,
random_start=False,
*args,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
Expand All @@ -129,10 +129,10 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
height=16,
width=16,
random_start=False,
*args,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
Expand All @@ -142,10 +142,10 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
height=5,
width=5,
random_start=True,
*args,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
Expand All @@ -155,10 +155,10 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
height=6,
width=6,
random_start=True,
*args,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
Expand All @@ -168,10 +168,10 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
height=8,
width=8,
random_start=True,
*args,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
Expand All @@ -181,10 +181,10 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
height=16,
width=16,
random_start=True,
*args,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
24 changes: 12 additions & 12 deletions navix/environments/key_corridor.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,72 +149,72 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
register_env(
"Navix-KeyCorridorS3R1-v0",
lambda *args, **kwargs: KeyCorridor.create(
*args,
**kwargs,
height=3,
width=7,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
register_env(
"Navix-KeyCorridorS3R2-v0",
lambda *args, **kwargs: KeyCorridor.create(
*args,
**kwargs,
height=5,
width=7,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
register_env(
"Navix-KeyCorridorS3R3-v0",
lambda *args, **kwargs: KeyCorridor.create(
*args,
**kwargs,
height=7,
width=7,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
register_env(
"Navix-KeyCorridorS4R3-v0",
lambda *args, **kwargs: KeyCorridor.create(
*args,
**kwargs,
height=10,
width=10,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
register_env(
"Navix-KeyCorridorS5R3-v0",
lambda *args, **kwargs: KeyCorridor.create(
*args,
**kwargs,
height=13,
width=13,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
register_env(
"Navix-KeyCorridorS6R3-v0",
lambda *args, **kwargs: KeyCorridor.create(
*args,
**kwargs,
height=16,
width=16,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
*args,
**kwargs,
),
)
2 changes: 1 addition & 1 deletion navix/rendering/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SPRITES_DIR = os.path.normpath(
os.path.join(__file__, "..", "..", "..", "assets", "sprites")
)
MIN_TILE_SIZE = 8
MIN_TILE_SIZE = 32
TILE_SIZE = MIN_TILE_SIZE


Expand Down
14 changes: 12 additions & 2 deletions tests/test_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

def test_room():
def f():
env = nx.environments.Room.create(height=3, width=3, max_steps=8)
env = nx.environments.Room.create(
height=3,
width=3,
max_steps=8,
observation_fn=nx.observations.symbolic_first_person,
)
key = jax.random.PRNGKey(4)
reset = jax.jit(env._reset)
step = jax.jit(env.step)
Expand Down Expand Up @@ -35,7 +40,12 @@ def f():

def test_keydoor():
def f():
env = nx.environments.DoorKey.create(height=5, width=10, max_steps=8)
env = nx.environments.DoorKey.create(
height=5,
width=10,
max_steps=8,
observation_fn=nx.observations.symbolic_first_person,
)
key = jax.random.PRNGKey(1)
reset = jax.jit(env._reset)
step = jax.jit(env.step)
Expand Down

0 comments on commit 73f924b

Please sign in to comment.