Skip to content

Commit

Permalink
support meta device compatibility (#1740)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1740

Adding meta device support to KeyedJaggedTensor to make Ads APF models compatible with meta device. This allows calculating FLOPs for models locally without allocating memory on cpu/gpu.

This diff makes the following changes:
* passes `values` tensor to  `_maybe_compute_length_per_key`
* when on meta device, creates dummy values for `length_per_key` that still sum to the size of `values` tensor so that downstream torch operators are valid.

The changes resolve the following error when running model forward with meta tensors (full trace P1191019670)
```
    torch.sum(lengths.view(-1, stride), dim=1).tolist()
NotImplementedError: Cannot copy out of meta tensor; no data!
```

## Additional Context
Loading the model on meta device allows us to load large FM models locally since meta tensors don't have values. See snippet below of a meta tensor which doesn't have values but contains valid `size` attribute which are used for FLOPs calculations:
```
torch.tensor([1, 2, 3], device=torch.device("meta"))
>>> tensor(..., device='meta', size=(3,), dtype=torch.int64)
```
By creating dummy values for `length_per_key` that sum to the total length of the `values` tensor, we make KJT operations valid. Furthermore, model embedding ops care about the relationship between the shape of the indices and the values of the lengths, but since a meta tensor carries no values, the shape stops mattering. The total number of operations remains the same because the total length of the `values` tensor is preserved. The test plan shows the FLOPs calculations are identical whether the model is loaded on cpu or on the meta device.

See design doc for more detailed context: https://docs.google.com/document/d/1DteShg9A8Nts3OTu2SrSxMp-0W2spZ9-v1br73z0LIQ/edit

Reviewed By: joshuadeng

Differential Revision: D54403867

fbshipit-source-id: a09ad16e269d241f66b2cba3377d878db65b581d
  • Loading branch information
edqwerty10 authored and facebook-github-bot committed Mar 4, 2024
1 parent a36195e commit 4e518b1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
12 changes: 11 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,15 @@ def _maybe_compute_length_per_key(
length_per_key: Optional[List[int]],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
values: Optional[torch.Tensor],
) -> List[int]:
if length_per_key is None:
if len(keys) and offsets is not None and len(offsets) > 0:
if values is not None and values.is_meta:
# create dummy lengths per key when on meta device
total_length = values.numel()
_length = [total_length // len(keys)] * len(keys)
_length[0] += total_length % len(keys)
elif len(keys) and offsets is not None and len(offsets) > 0:
_length: List[int] = (
_length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
if variable_stride_per_key
Expand Down Expand Up @@ -789,6 +795,7 @@ def _maybe_compute_offset_per_key(
offset_per_key: Optional[List[int]],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
values: Optional[torch.Tensor],
) -> Tuple[List[int], List[int]]:
if length_per_key is None:
_length_per_key: List[int] = _maybe_compute_length_per_key(
Expand All @@ -799,6 +806,7 @@ def _maybe_compute_offset_per_key(
length_per_key=length_per_key,
lengths=lengths,
offsets=offsets,
values=values,
)
return _length_per_key, _cumsum(_length_per_key)
elif offset_per_key is None:
Expand Down Expand Up @@ -1562,6 +1570,7 @@ def length_per_key(self) -> List[int]:
length_per_key=self._length_per_key,
lengths=self._lengths,
offsets=self._offsets,
values=self._values,
)
self._length_per_key = _length_per_key
return _length_per_key
Expand All @@ -1579,6 +1588,7 @@ def offset_per_key(self) -> List[int]:
offset_per_key=self._offset_per_key,
lengths=self._lengths,
offsets=self._offsets,
values=self._values,
)
self._length_per_key = _length_per_key
self._offset_per_key = _offset_per_key
Expand Down
39 changes: 39 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,45 @@ def test_equality(self) -> None:
non_kjt_input = "not a KeyedJaggedTensor instance"
self.assertFalse(kjt_is_equal(kt, non_kjt_input))

def test_meta_device_compatibility(self) -> None:
    """Smoke-test KJT construction and round-trips with meta-device tensors.

    Every call below would raise (e.g. ``NotImplementedError: Cannot copy
    out of meta tensor``) if any KJT code path tried to read tensor data,
    so simply completing without an exception is the assertion.
    """
    meta = torch.device("meta")
    keys = ["index_0", "index_1", "index_2", "index_3"]
    lengths = torch.tensor([2, 0, 1, 1, 1, 3, 0, 2], device=meta)
    offsets = torch.tensor([0, 2, 2, 3, 4, 5, 8, 8, 10], device=meta)
    # values and weights share the same payload; only shape/device matter on meta.
    payload = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    values = torch.tensor(payload, device=meta)
    weights = torch.tensor(payload, device=meta)

    kjt = KeyedJaggedTensor(
        keys=keys,
        values=values,
        weights=weights,
        lengths=lengths,
    )

    # sync/unsync compute (and drop) per-key lengths/offsets.
    kjt.sync()
    kjt.unsync()

    # Round-trip through the per-key JaggedTensor dict representation.
    kjt = KeyedJaggedTensor.from_jt_dict(kjt.to_dict())

    # Alternate constructors must also tolerate meta tensors.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=keys, values=values, weights=weights, lengths=lengths
    )
    kjt = KeyedJaggedTensor.from_offsets_sync(
        keys=keys, values=values, weights=weights, offsets=offsets
    )


class TestKeyedJaggedTensorScripting(unittest.TestCase):
def test_scriptable_forward(self) -> None:
Expand Down

0 comments on commit 4e518b1

Please sign in to comment.