diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py
index 629d83a6dd3..2e8fe407b3d 100644
--- a/benchmarks/test_objectives_benchmarks.py
+++ b/benchmarks/test_objectives_benchmarks.py
@@ -1152,4 +1152,7 @@ def loss_and_bw(td):
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
-    pytest.main([__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + unknown)
+    pytest.main(
+        [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
+        + unknown
+    )
diff --git a/setup.py b/setup.py
index 823ec307052..33d9d5c7268 100644
--- a/setup.py
+++ b/setup.py
@@ -209,7 +209,17 @@ def _main(argv):
         "dm_control": ["dm_control"],
         "gym_continuous": ["gymnasium<1.0", "mujoco"],
         "rendering": ["moviepy<2.0.0"],
-        "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
+        "tests": [
+            "pytest",
+            "pyyaml",
+            "pytest-instafail",
+            "scipy",
+            "pytest-mock",
+            "pytest-cov",
+            "pytest-benchmark",
+            "pytest-rerunfailures",
+            "pytest-error-for-skips",
+        ],
         "utils": [
             "tensorboard",
             "wandb",
diff --git a/test/test_rlhf.py b/test/test_rlhf.py
index 486ddbef127..1c09856dd36 100644
--- a/test/test_rlhf.py
+++ b/test/test_rlhf.py
@@ -298,12 +298,10 @@ def test_tensordict_tokenizer(
             "Lettuce in, it's cold out here!",
         ]
     }
-    if not truncation and return_tensordict and max_length == 10:
-        with pytest.raises(ValueError, match="TensorDict conversion only supports"):
-            out = process(example)
-        return
     out = process(example)
-    if return_tensordict:
+    if not truncation and return_tensordict and max_length == 10:
+        assert out.get("input_ids").shape[-1] == -1
+    elif return_tensordict:
         assert out.get("input_ids").shape[-1] == max_length
     else:
         obj = out.get("input_ids")
@@ -346,12 +344,10 @@ def test_prompt_tensordict_tokenizer(
         ],
         "label": ["right", "wrong", "right", "wrong", "right"],
     }
-    if not truncation and return_tensordict and max_length == 10:
-        with pytest.raises(ValueError, match="TensorDict conversion only supports"):
-            out = process(example)
-        return
     out = process(example)
-    if return_tensordict:
+    if not truncation and return_tensordict and max_length == 10:
+        assert out.get("input_ids").shape[-1] == -1
+    elif return_tensordict:
         assert out.get("input_ids").shape[-1] == max_length
     else:
         obj = out.get("input_ids")
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 949f5e3b621..8adf36b0019 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -6,9 +6,9 @@
 from __future__ import annotations
 
 import abc
-import functools
 import warnings
 from copy import deepcopy
+from functools import partial, wraps
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
 
 import numpy as np
@@ -33,6 +33,7 @@
     _StepMDP,
     _terminated_or_truncated,
     _update_during_reset,
+    check_env_specs as check_env_specs_func,
     get_available_libraries,
 )
 
@@ -2035,7 +2036,7 @@ def _register_gym(
         if entry_point is None:
             entry_point = cls
 
-        entry_point = functools.partial(
+        entry_point = partial(
             _TorchRLGymWrapper,
             entry_point=entry_point,
             info_keys=info_keys,
@@ -2084,7 +2085,7 @@ def _register_gym(  # noqa: F811
         if entry_point is None:
             entry_point = cls
 
-        entry_point = functools.partial(
+        entry_point = partial(
             _TorchRLGymWrapper,
             entry_point=entry_point,
             info_keys=info_keys,
@@ -2138,7 +2139,7 @@ def _register_gym(  # noqa: F811
         if entry_point is None:
             entry_point = cls
 
-        entry_point = functools.partial(
+        entry_point = partial(
             _TorchRLGymWrapper,
             entry_point=entry_point,
             info_keys=info_keys,
@@ -2195,7 +2196,7 @@ def _register_gym(  # noqa: F811
         if entry_point is None:
             entry_point = cls
 
-        entry_point = functools.partial(
+        entry_point = partial(
             _TorchRLGymWrapper,
             entry_point=entry_point,
             info_keys=info_keys,
@@ -2254,7 +2255,7 @@ def _register_gym(  # noqa: F811
         )
         if entry_point is None:
             entry_point = cls
-        entry_point = functools.partial(
+        entry_point = partial(
             _TorchRLGymWrapper,
             entry_point=entry_point,
             info_keys=info_keys,
@@ -2293,7 +2294,7 @@ def _register_gym(  # noqa: F811
         if entry_point is None:
             entry_point = cls
 
-        entry_point = functools.partial(
+        entry_point = partial(
             _TorchRLGymnasiumWrapper,
             entry_point=entry_point,
             info_keys=info_keys,
@@ -3422,11 +3423,11 @@ def _get_sync_func(policy_device, env_device):
         if policy_device is not None and policy_device.type == "cuda":
             if env_device is None or env_device.type == "cuda":
                 return torch.cuda.synchronize
-            return functools.partial(torch.cuda.synchronize, device=policy_device)
+            return partial(torch.cuda.synchronize, device=policy_device)
         if env_device is not None and env_device.type == "cuda":
             if policy_device is None:
                 return torch.cuda.synchronize
-            return functools.partial(torch.cuda.synchronize, device=env_device)
+            return partial(torch.cuda.synchronize, device=env_device)
         return torch.cuda.synchronize
     if torch.backends.mps.is_available():
         return torch.mps.synchronize