Skip to content

Commit

Permalink
[BugFix] extract the info dict from a list (#1131)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <[email protected]>
  • Loading branch information
xmaples and vmoens authored May 5, 2023
1 parent 47579aa commit 24abc75
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 18 deletions.
1 change: 1 addition & 0 deletions .circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ do
echo "Testing gym version: ${GYM_VERSION}"
pip3 install 'gym[accept-rom-license]'==$GYM_VERSION
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install gym-super-mario-bros
$DIR/run_test.sh

# delete the conda copy
Expand Down
82 changes: 65 additions & 17 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,24 +100,24 @@


@pytest.mark.skipif(not _has_gym, reason="no gym library found")
@pytest.mark.parametrize(
"env_name",
[
PONG_VERSIONED,
# PENDULUM_VERSIONED,
HALFCHEETAH_VERSIONED,
],
)
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only",
[
[False, False],
[True, True],
[True, False],
],
)
class TestGym:
@pytest.mark.parametrize(
"env_name",
[
PONG_VERSIONED,
# PENDULUM_VERSIONED,
HALFCHEETAH_VERSIONED,
],
)
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only",
[
[False, False],
[True, True],
[True, False],
],
)
def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
if env_name == PONG_VERSIONED and not from_pixels:
# raise pytest.skip("already pixel")
Expand Down Expand Up @@ -176,6 +176,23 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only):
assert final_seed0 == final_seed2
assert_allclose_td(tdrollout[0], rollout2, rtol=RTOL, atol=ATOL)

@pytest.mark.parametrize(
"env_name",
[
PONG_VERSIONED,
# PENDULUM_VERSIONED,
HALFCHEETAH_VERSIONED,
],
)
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only",
[
[False, False],
[True, True],
[True, False],
],
)
def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
if env_name == PONG_VERSIONED and not from_pixels:
# raise pytest.skip("already pixel")
Expand All @@ -195,6 +212,37 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
)
check_env_specs(env)

def test_info_reader(self):
try:
import gym_super_mario_bros as mario_gym
except ImportError as err:
try:
import gym

# with 0.26 we must have installed gym_super_mario_bros
# Since we capture the skips as errors, we raise a skip in this case
# Otherwise, we just return
if (
version.parse("0.26.0")
<= version.parse(gym.__version__)
< version.parse("0.27.0")
):
raise pytest.skip(f"no super mario bros: error=\n{err}")
except ImportError:
pass
return

env = mario_gym.make("SuperMarioBros-v0", apply_api_compatibility=True)
env = GymWrapper(env)

def info_reader(info, tensordict):
assert isinstance(info, dict) # failed before bugfix

env.info_dict_reader = info_reader
env.reset()
env.rand_step()
env.rollout(3)


@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _reset(
obs, *other = self._output_transform(reset_data)
info = None
if len(other) == 1:
info = other
info = other[0]

tensordict_out = TensorDict(
source=self.read_obs(obs),
Expand Down

0 comments on commit 24abc75

Please sign in to comment.