[Feature] NJT with lengths #1021

Merged · 7 commits · Oct 4, 2024
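For context: a jagged-layout nested tensor (NJT) is built from a flat `values` buffer plus `offsets`, and may additionally carry a `lengths` tensor that lets each row cover only part of the buffer. Serialization previously kept only values and offsets, dropping `lengths`. A minimal sketch of the two constructions this PR distinguishes (values and offsets taken from the new test below):

```python
import torch

values = torch.arange(10)
offsets = torch.tensor([0, 2, 5, 8, 10])

# Offsets only: row i is values[offsets[i]:offsets[i + 1]]; _lengths is None.
njt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)

# With explicit lengths: row i is values[offsets[i]:offsets[i] + lengths[i]],
# and the resulting tensor keeps a non-None _lengths that must survive
# serialization.
lengths = torch.tensor([2, 3, 3, 2])
njt_l = torch.nested.nested_tensor_from_jagged(
    values, offsets=offsets, lengths=lengths
)
```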
21 changes: 17 additions & 4 deletions tensordict/_reductions.py
@@ -32,16 +32,23 @@ def from_metadata(metadata=metadata_dict, prefix=None):
             key: NonTensorData(data, batch_size=batch_size)
             for (key, (data, batch_size)) in non_tensor.items()
         }
-        for key, _ in leaves.items():
+        for key in leaves.keys():
             total_key = (key,) if prefix is None else prefix + (key,)
             if total_key[-1].startswith("<NJT>"):
                 nested_values = flat_key_values[total_key]
+                nested_lengths = None
                 continue
-            if total_key[-1].startswith("<NJT_OFFSETS"):
+            if total_key[-1].startswith("<NJT_LENGTHS>"):
+                nested_lengths = flat_key_values[total_key]
+                continue
+            elif total_key[-1].startswith("<NJT_OFFSETS"):
                 offsets = flat_key_values[total_key]
                 key = key.replace("<NJT_OFFSETS>", "")
-                value = torch.nested.nested_tensor_from_jagged(nested_values, offsets)
+                value = torch.nested.nested_tensor_from_jagged(
+                    nested_values, offsets=offsets, lengths=nested_lengths
+                )
                 del nested_values
+                del nested_lengths
             else:
                 value = flat_key_values[total_key]
             d[key] = value
@@ -93,10 +100,16 @@ def from_metadata(metadata=metadata, prefix=None):
             value = value.view(local_shape)
             if key.startswith("<NJT>"):
                 nested_values = value
+                nested_lengths = None
                 continue
+            elif key.startswith("<NJT_LENGTHS>"):
+                nested_lengths = value
+                continue
             elif key.startswith("<NJT_OFFSETS>"):
                 offsets = value
-                value = torch.nested.nested_tensor_from_jagged(nested_values, offsets)
+                value = torch.nested.nested_tensor_from_jagged(
+                    nested_values, offsets=offsets, lengths=nested_lengths
+                )
                 key = key.replace("<NJT_OFFSETS>", "")
             d[key] = value
         for k, v in metadata.items():
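The deserialization above relies on dict insertion order: for each NJT leaf, the `<NJT>` values entry arrives first, then an optional `<NJT_LENGTHS>` entry, then the `<NJT_OFFSETS>` entry that triggers reconstruction. A stripped-down sketch of that replay loop (hypothetical string keys in place of the real tuple keys):

```python
import torch

# Flat serialized entries, in the order the writer emits them.
flat = {
    "<NJT>x": torch.arange(10),                        # values buffer
    "<NJT_LENGTHS>x": torch.tensor([2, 3, 3, 2]),      # optional lengths
    "<NJT_OFFSETS>x": torch.tensor([0, 2, 5, 8, 10]),  # offsets, read last
}

d = {}
nested_values, nested_lengths = None, None
for key, value in flat.items():
    if key.startswith("<NJT>"):
        nested_values, nested_lengths = value, None
    elif key.startswith("<NJT_LENGTHS>"):
        nested_lengths = value
    elif key.startswith("<NJT_OFFSETS>"):
        d[key.replace("<NJT_OFFSETS>", "")] = torch.nested.nested_tensor_from_jagged(
            nested_values, offsets=value, lengths=nested_lengths
        )

assert d["x"]._lengths is not None  # lengths were restored
```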
70 changes: 53 additions & 17 deletions tensordict/base.py
@@ -3590,7 +3590,10 @@ def assign(
                     values = value.values()
                     shape = [v if isinstance(v, int) else -1 for v in values.shape]
                     # Get the offsets
-                    offsets = value.offsets()
+                    offsets = value._offsets
+                    # Get the lengths
+                    lengths = value._lengths
+
                     # Now we're saving the two tensors
                     # We will rely on the fact that the writing order is preserved in python dict
                     # (since python 3.7). Later, we will read the NJT then the NJT offset in that order
@@ -3602,9 +3605,22 @@ def assign(
                         metadata_dict,
                         values.dtype,
                         shape,
-                        # values.device,
                         flat_size,
                     )
+                    # Lengths
+                    if lengths is not None:
+                        flat_key_values[
+                            _prefix_last_key(total_key, "<NJT_LENGTHS>")
+                        ] = lengths
+                        add_single_value(
+                            lengths,
+                            _prefix_last_key(key, "<NJT_LENGTHS>"),
+                            metadata_dict,
+                            lengths.dtype,
+                            lengths.shape,
+                            flat_size,
+                        )
+                    # Offsets
                     flat_key_values[_prefix_last_key(total_key, "<NJT_OFFSETS>")] = (
                         offsets
                     )
@@ -3614,9 +3630,9 @@ def assign(
                         metadata_dict,
                         offsets.dtype,
                         offsets.shape,
-                        # offsets.device,
                         flat_size,
                     )
+
                 else:
                     raise NotImplementedError(
                         "NST is not supported, please use layout=torch.jagged when building the nested tensor."
@@ -3785,12 +3801,14 @@ def view_old_as_new(v, oldv):
         if num_threads > 0:

             def assign(
+                *,
                 k,
                 v,
                 start,
                 stop,
                 njts,
                 njts_offsets,
+                njts_lengths,
                 storage=storage,
                 non_blocking=non_blocking,
             ):
@@ -3810,25 +3828,29 @@ def assign(
                 new_v = new_v.view(shape)
                 if k[-1].startswith("<NJT>"):
                     njts[k] = new_v
+                elif k[-1].startswith("<NJT_LENGTHS>"):
+                    njts_lengths[k] = new_v
                 elif k[-1].startswith("<NJT_OFFSETS>"):
                     njts_offsets[k] = new_v
                 flat_dict[k] = new_v

             njts = {}
             njts_offsets = {}
+            njts_lengths = {}
             if num_threads > 1:
                 executor = ThreadPoolExecutor(num_threads)
                 r = []
                 for i, (k, v) in enumerate(flat_dict.items()):
                     r.append(
                         executor.submit(
                             assign,
-                            k,
-                            v,
-                            offsets[i],
-                            offsets[i + 1],
-                            njts,
-                            njts_offsets,
+                            k=k,
+                            v=v,
+                            start=offsets[i],
+                            stop=offsets[i + 1],
+                            njts=njts,
+                            njts_offsets=njts_offsets,
+                            njts_lengths=njts_lengths,
                         )
                     )
                 if not return_early:
@@ -3841,22 +3863,29 @@ def assign(
             else:
                 for i, (k, v) in enumerate(flat_dict.items()):
                     assign(
-                        k,
-                        v,
-                        offsets[i],
-                        offsets[i + 1],
-                        njts,
-                        njts_offsets,
+                        k=k,
+                        v=v,
+                        start=offsets[i],
+                        stop=offsets[i + 1],
+                        njts=njts,
+                        njts_offsets=njts_offsets,
+                        njts_lengths=njts_lengths,
                     )
             for njt_key, njt_val in njts.items():
                 njt_key_offset = njt_key[:-1] + (
                     njt_key[-1].replace("<NJT>", "<NJT_OFFSETS>"),
                 )
+                njt_key_lengths = njt_key[:-1] + (
+                    njt_key[-1].replace("<NJT>", "<NJT_LENGTHS>"),
+                )
                 val = torch.nested.nested_tensor_from_jagged(
-                    njt_val, flat_dict[njt_key_offset]
+                    njt_val,
+                    offsets=flat_dict[njt_key_offset],
+                    lengths=flat_dict.get(njt_key_lengths),
                 )
                 del flat_dict[njt_key]
                 del flat_dict[njt_key_offset]
+                flat_dict.pop(njt_key_lengths, None)
                 newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
                 flat_dict[newkey] = val

@@ -3896,13 +3925,20 @@ def _view_and_pad(tensor):
             elif k[-1].startswith("<NJT>"):
                 # NJT/NT always comes before offsets/shapes
                 _nested_values = view_old_as_new(v, oldv)
+                nt_lengths = None
                 del flat_dict[k]
+            elif k[-1].startswith("<NJT_LENGTHS>"):
+                nt_lengths = view_old_as_new(v, oldv)
+                del flat_dict[k]
             elif k[-1].startswith("<NJT_OFFSETS>"):
                 newk = k[:-1] + (k[-1].replace("<NJT_OFFSETS>", ""),)
                 nt_offsets = view_old_as_new(v, oldv)
                 del flat_dict[k]
+
                 flat_dict[newk] = torch.nested.nested_tensor_from_jagged(
-                    _nested_values, nt_offsets
+                    _nested_values,
+                    offsets=nt_offsets,
+                    lengths=nt_lengths,
                 )
                 # delete the nested value to make sure that if there was an
                 # ordering mismatch we wouldn't be looking at the value key of
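Taken together, the serialization side now writes up to three adjacent entries per NJT leaf (`<NJT>` values, optional `<NJT_LENGTHS>`, `<NJT_OFFSETS>`), and every reader rebuilds the tensor with `lengths=` when the lengths entry is present. The end-to-end effect, mirroring the new test below (a sketch, assuming torch >= 2.5 and a tensordict build with this change):

```python
import torch
from tensordict import TensorDict

njt = torch.nested.nested_tensor_from_jagged(
    torch.arange(10),
    offsets=torch.tensor([0, 2, 5, 8, 10]),
    lengths=torch.tensor([2, 3, 3, 2]),
)
td = TensorDict({"njt_lengths": njt}, batch_size=[4])

td_c = td.consolidate()  # values, lengths and offsets share one flat storage
assert td_c["njt_lengths"]._lengths is not None  # lengths survive the round trip
```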
3 changes: 3 additions & 0 deletions tensordict/utils.py
@@ -1541,6 +1541,9 @@ def assert_close(
             continue
         elif not isinstance(input1, torch.Tensor):
             continue
+        if input1.is_nested:
+            input1 = input1._base
+            input2 = input2._base
         mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
         mse = mse.div(input1.numel()).sqrt().item()

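Why this is needed: comparing two NJTs elementwise is ill-defined when their offsets/lengths metadata differ, so `assert_close` now falls back to the dense buffer the jagged views share (`._base`). A rough public-API equivalent of that comparison (a sketch; `.values()` plays the role of `._base` here):

```python
import torch

def nested_rmse(a: torch.Tensor, b: torch.Tensor) -> float:
    # For nested inputs, compare the flat value buffers rather than
    # the jagged views themselves.
    if a.is_nested:
        a, b = a.values(), b.values()
    mse = (a.to(torch.float) - b.to(torch.float)).pow(2).sum()
    return mse.div(a.numel()).sqrt().item()
```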
73 changes: 72 additions & 1 deletion test/test_tensordict.py
@@ -9,6 +9,7 @@

 import functools
 import gc
+import importlib.util
 import json
 import os
 import pathlib
@@ -22,8 +23,8 @@

 import numpy as np
 import pytest
-import tensordict.base as tensordict_base

+import tensordict.base as tensordict_base
 import torch
 from _utils_internal import (
     decompose,
@@ -32,6 +33,7 @@
     prod,
     TestTensorDictsBase,
 )
+from packaging import version

 from tensordict import (
     get_defaults_to_none,
@@ -42,6 +44,7 @@
     TensorDict,
 )
 from tensordict._lazy import _CustomOpTensorDict
+from tensordict._reductions import _reduce_td
 from tensordict._td import _SubTensorDict, is_tensor_collection
 from tensordict._torch_func import _stack as stack_td
 from tensordict.base import _is_leaf_nontensor, _NESTED_TENSORS_AS_LISTS, TensorDictBase
@@ -89,6 +92,11 @@
     _has_h5py = True
 except ImportError:
     _has_h5py = False
+TORCH_VERSION = version.parse(torch.__version__).base_version
+
+_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
+
+_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")

 _IS_OSX = platform.system() == "Darwin"
 _IS_WINDOWS = sys.platform == "win32"
@@ -7852,6 +7860,7 @@ def test_consolidate(self, device, use_file, tmpdir):
             batch_size=[1, 3],
         )
         td = LazyStackedTensorDict(*td.unbind(1), stack_dim=1)
+
         if not use_file:
             td_c = td.consolidate()
             assert td_c.device == device
@@ -7864,6 +7873,7 @@ def test_consolidate(self, device, use_file, tmpdir):
         assert type(td_c) == type(td)  # noqa
         assert (td.to(td_c.device) == td_c).all()
         assert td_c["d"] == [["a string!"] * 3]
+
         storage = td_c._consolidated["storage"]
         storage *= 0
         assert (td.to(td_c.device) != td_c).any()
@@ -7887,6 +7897,67 @@ def check_id(a, b):
         torch.utils._pytree.tree_map(check_id, td_c._consolidated, tdload._consolidated)
         assert tdload.is_consolidated()

+    @pytest.mark.skipif(not _v2_5, reason="v2.5 required for this test")
+    @pytest.mark.parametrize("device", [None, *get_available_devices()])
+    @pytest.mark.parametrize("use_file", [False, True])
+    def test_consolidate_njt(self, device, use_file, tmpdir):
+        td = TensorDict(
+            {
+                "a": torch.arange(3).expand(4, 3).clone(),
+                "b": {"c": torch.arange(3, dtype=torch.double).expand(4, 3).clone()},
+                "d": "a string!",
+                "njt": torch.nested.nested_tensor_from_jagged(
+                    torch.arange(10, device=device),
+                    offsets=torch.tensor([0, 2, 5, 8, 10], device=device),
+                ),
+                "njt_lengths": torch.nested.nested_tensor_from_jagged(
+                    torch.arange(10, device=device),
+                    offsets=torch.tensor([0, 2, 5, 8, 10], device=device),
+                    lengths=torch.tensor([2, 3, 3, 2], device=device),
+                ),
+            },
+            device=device,
+            batch_size=[4],
+        )
+
+        if not use_file:
+            td_c = td.consolidate()
+            assert td_c.device == device
+        else:
+            filename = Path(tmpdir) / "file.mmap"
+            td_c = td.consolidate(filename=filename)
+            assert td_c.device == torch.device("cpu")
+            assert assert_allclose_td(TensorDict.from_consolidated(filename), td_c)
+        assert hasattr(td_c, "_consolidated")
+        assert type(td_c) == type(td)  # noqa
+        assert td_c["d"] == "a string!"
+        with (
+            pytest.raises(KeyError)
+            if td.device != td_c.device and device is not None
+            else contextlib.nullcontext()
+        ):
+            # njt.to(device) is currently broken when it has lengths
+            assert_allclose_td(td.to(td_c.device), td_c)
+
+        tdload_make, tdload_data = _reduce_td(td)
+        tdload = tdload_make(*tdload_data)
+        assert (td == tdload).all()
+
+        td_c = td.consolidate()
+        tdload_make, tdload_data = _reduce_td(td_c)
+        tdload = tdload_make(*tdload_data)
+        assert assert_allclose_td(td, tdload)
+
+        def check_id(a, b):
+            if isinstance(a, (torch.Size, str)):
+                assert a == b
+            if isinstance(a, torch.Tensor):
+                assert (a == b).all()
+
+        torch.utils._pytree.tree_map(check_id, td_c._consolidated, tdload._consolidated)
+        assert tdload.is_consolidated()
+        assert tdload["njt_lengths"]._lengths is not None
+
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
     def test_consolidate_to_device(self):
         td = TensorDict(