[BE] single dim check helper #1192

Merged 2 commits on Feb 4, 2025
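
In brief (inferred from the diff below): the inline negative-dim normalization and bounds check that was duplicated across all, any, proc_dim, split, and _squeeze in tensordict/_td.py, and across unbind, unflatten, _map, and softmax in tensordict/base.py, is consolidated into a single helper, _maybe_correct_neg_dim, added in tensordict/utils.py. Out-of-range dims now uniformly raise IndexError ("Incompatible dim ..."), and test/test_tensordict.py is updated to match the new message; a usage sketch of the helper follows the tensordict/utils.py hunk.
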
36 changes: 6 additions & 30 deletions tensordict/_td.py
@@ -57,6 +57,7 @@
     _is_shared,
     _KEY_ERROR,
     _LOCK_ERROR,
+    _maybe_correct_neg_dim,
     _mismatch_keys,
     _NON_STR_KEY_ERR,
     _NON_STR_KEY_TUPLE_ERR,
@@ -880,8 +881,7 @@ def all(self, dim: int = None) -> bool | TensorDictBase:
                 "smaller than tensordict.batch_dims"
             )
         if dim is not None:
-            if dim < 0:
-                dim = self.batch_dims + dim
+            dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
             names = None
             if self._has_names():
@@ -903,8 +903,7 @@ def any(self, dim: int = None) -> bool | TensorDictBase:
                 "smaller than tensordict.batch_dims"
             )
         if dim is not None:
-            if dim < 0:
-                dim = self.batch_dims + dim
+            dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
             names = None
             if self._has_names():
@@ -982,14 +981,7 @@ def proc_dim(dim, batch_dims, tuple_ok=True):
                         for _d in proc_dim(d, batch_dims, tuple_ok=False)
                     )
                 return dim
-            if dim >= batch_dims or dim < -batch_dims:
-                raise RuntimeError(
-                    "dim must be greater than or equal to -tensordict.batch_dims and "
-                    "smaller than tensordict.batch_dims"
-                )
-            if dim < 0:
-                return (batch_dims + dim,)
-            return (dim,)
+            return (_maybe_correct_neg_dim(dim, None, batch_dims),)
 
         dim_needs_proc = (dim is not NO_DEFAULT) and (dim not in ("feature",))
         if dim_needs_proc:
@@ -1726,13 +1718,7 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
         WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
         batch_size = self.batch_size
         batch_sizes = []
-        batch_dims = len(batch_size)
-        if dim < 0:
-            dim = len(batch_size) + dim
-        if dim >= batch_dims or dim < 0:
-            raise IndexError(
-                f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
-            )
+        dim = _maybe_correct_neg_dim(dim, batch_size)
         max_size = batch_size[dim]
         if isinstance(split_size, int):
             idx0 = 0
@@ -2007,17 +1993,7 @@ def _squeeze(tensor):
             propagate_lock=True,
         )
         # make the dim positive
-        if dim < 0:
-            newdim = self.batch_dims + dim
-        else:
-            newdim = dim
-
-        if (newdim >= self.batch_dims) or (newdim < 0):
-            raise RuntimeError(
-                f"squeezing is allowed for dims comprised between "
-                f"`-td.batch_dims` and `td.batch_dims - 1` only. Got "
-                f"dim={dim} with a batch size of {self.batch_size}."
-            )
+        newdim = _maybe_correct_neg_dim(dim, batch_size)
         if batch_size[dim] != 1:
             return self
         batch_size = list(batch_size)
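
The user-facing behavior of negative dims should be unchanged by this refactor; a quick sanity check of the pattern these call sites share (a toy example with hypothetical values, using the public TensorDict API):

import torch

from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3, 4, dtype=torch.bool)}, batch_size=[3, 4])

# dim=-1 normalizes to dim=1, so the reduction drops the last batch dim.
assert td.all(dim=-1).batch_size == torch.Size([3])
assert td.all(dim=1).batch_size == torch.Size([3])
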
33 changes: 6 additions & 27 deletions tensordict/base.py
@@ -65,6 +65,7 @@
     _KEY_ERROR,
     _lock_warn,
     _make_dtype_promotion,
+    _maybe_correct_neg_dim,
     _parse_to,
     _pass_through,
     _pass_through_cls,
@@ -3343,13 +3344,7 @@ def unbind(self, dim: int) -> tuple[T, ...]:
             tensor([4, 5, 6, 7])
 
         """
-        batch_dims = self.batch_dims
-        if dim < -batch_dims or dim >= batch_dims:
-            raise RuntimeError(
-                f"the dimension provided ({dim}) is beyond the tensordict dimensions ({self.ndim})."
-            )
-        if dim < 0:
-            dim = batch_dims + dim
+        dim = _maybe_correct_neg_dim(dim, self.batch_size)
         results = self._unbind(dim)
         if self._is_memmap or self._is_shared:
             for result in results:
@@ -7385,12 +7380,7 @@ def unflatten(self, dim, unflattened_size):
             >>> td_unflat = td_flat.unflatten(0, [3, 4])
             >>> assert (td == td_unflat).all()
         """
-        if dim < 0:
-            dim = self.ndim + dim
-            if dim < 0:
-                raise ValueError(
-                    f"Incompatible dim {dim} for tensordict with shape {self.shape}."
-                )
+        dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
         def unflatten(tensor):
             return torch.unflatten(
@@ -8935,11 +8925,7 @@ def _map(
         iterable: bool,
     ):
         num_workers = pool._processes
-        dim_orig = dim
-        if dim < 0:
-            dim = self.ndim + dim
-        if dim < 0 or dim >= self.ndim:
-            raise ValueError(f"Got incompatible dimension {dim_orig}")
+        dim = _maybe_correct_neg_dim(dim, self.batch_size)
 
         self_split = _split_tensordict(
             self,
@@ -9567,18 +9553,11 @@ def softmax(self, dim: int, dtype: torch.dtype | None = None): # noqa: D417
 
         """
         if isinstance(dim, int):
-            if dim < 0:
-                new_dim = self.ndim + dim
-            else:
-                new_dim = dim
+            dim = _maybe_correct_neg_dim(dim, self.batch_size)
         else:
             raise ValueError(f"Expected dim of type int, got {type(dim)}.")
-        if (new_dim < 0) or (new_dim >= self.ndim):
-            raise ValueError(
-                f"The dimension {dim} is incompatible with a tensordict with batch_size {self.batch_size}."
-            )
         return self._fast_apply(
-            lambda x: torch.softmax(x, dim=new_dim, dtype=dtype),
+            lambda x: torch.softmax(x, dim=dim, dtype=dtype),
         )
 
     def log10(self) -> T:
21 changes: 21 additions & 0 deletions tensordict/utils.py
@@ -2925,3 +2925,24 @@ def set_mode(self, type: Any | None) -> None:
         cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             self._mode = type
+
+
+def _maybe_correct_neg_dim(
+    dim: int, shape: torch.Size | None, ndim: int | None = None
+) -> int:
+    """Corrects neg dim to pos."""
+    if ndim is None:
+        ndim = len(shape)
+    if dim < 0:
+        new_dim = ndim + dim
+    else:
+        new_dim = dim
+    if new_dim < 0 or new_dim >= ndim:
+        if shape is not None:
+            raise IndexError(
+                f"Incompatible dim {new_dim} for tensordict with shape {shape}."
+            )
+        raise IndexError(
+            f"Incompatible dim {new_dim} for tensordict with batch dims {ndim}."
+        )
+    return new_dim
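
A minimal usage sketch of the new helper, assuming it is imported from tensordict.utils as added above (expected values follow from the logic in the diff):

import torch

from tensordict.utils import _maybe_correct_neg_dim

shape = torch.Size([3, 4])

# In-range dims pass through unchanged; negative dims wrap around.
assert _maybe_correct_neg_dim(0, shape) == 0
assert _maybe_correct_neg_dim(-1, shape) == 1

# proc_dim-style call: no shape at hand, only the number of batch dims.
assert _maybe_correct_neg_dim(-2, None, ndim=2) == 0

# Out-of-range dims raise IndexError with the message the updated tests match on.
try:
    _maybe_correct_neg_dim(2, shape)
except IndexError as err:
    assert "Incompatible dim" in str(err)
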
6 changes: 3 additions & 3 deletions test/test_tensordict.py
@@ -2566,7 +2566,7 @@ def test_split_with_empty_tensordict(self):
     def test_split_with_invalid_arguments(self):
         td = TensorDict({"a": torch.zeros(2, 1)}, [])
         # Test empty batch size
-        with pytest.raises(IndexError, match="Dimension out of range"):
+        with pytest.raises(IndexError, match="Incompatible dim"):
             td.split(1, 0)
 
         td = TensorDict({}, [3, 2])
@@ -2587,9 +2587,9 @@ def test_split_with_invalid_arguments(self):
             td.split([1, 1], 0)
 
         # Test invalid dimension input
-        with pytest.raises(IndexError, match="Dimension out of range"):
+        with pytest.raises(IndexError, match="Incompatible dim"):
             td.split(1, 2)
-        with pytest.raises(IndexError, match="Dimension out of range"):
+        with pytest.raises(IndexError, match="Incompatible dim"):
             td.split(1, -3)
 
     def test_split_with_negative_dim(self):