Skip to content

Commit

Permalink
[Versioning] v0.6.1 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 2db58fd commit 4f2cb07
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions benchmarks/test_replaybuffer_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LazyMemmapStorage,
LazyTensorStorage,
ListStorage,
ReplayBuffer,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _main(argv):
if is_nightly:
tensordict_dep = "tensordict-nightly"
else:
tensordict_dep = "tensordict>=0.6.1"
tensordict_dep = "tensordict>=0.6.2"

if is_nightly:
version = get_nightly_version()
Expand Down
9 changes: 8 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@
)
from pytorch.rl.test.mocking_classes import CountingEnv
else:
from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
from mocking_classes import CountingEnv

from packaging import version
from packaging.version import parse
from tensordict import (
Expand Down Expand Up @@ -119,6 +125,7 @@
_has_gym = importlib.util.find_spec("gym") is not None
_has_snapshot = importlib.util.find_spec("torchsnapshot") is not None
_os_is_windows = sys.platform == "win32"
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

torch_2_3 = version.parse(
".".join([str(s) for s in version.parse(str(torch.__version__)).release])
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import capture_log_records, get_default_devices
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for
Expand Down

0 comments on commit 4f2cb07

Please sign in to comment.