From 4e518b154129309cb4b0495652c7aa8d67e04e2f Mon Sep 17 00:00:00 2001 From: Edson Romero Date: Mon, 4 Mar 2024 11:24:19 -0800 Subject: [PATCH] support meta device compatibility (#1740) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1740 Adding meta device support to KeyedJaggedTensor to make Ads APF models compatible with meta device. This allows calculating FLOPs for models locally without allocating memory on cpu/gpu. This diff makes the following changes: * passes `values` tensor to `_maybe_compute_length_per_key` * when on meta device, creates dummy values for `length_per_key` that still sum to the size of the `values` tensor so that downstream torch operators are valid. The changes resolve the following error when running model forward with meta tensors (full trace P1191019670) ``` torch.sum(lengths.view(-1, stride), dim=1).tolist() NotImplementedError: Cannot copy out of meta tensor; no data! ``` ## Additional Context Loading the model on meta device allows us to load large FM models locally since meta tensors don't have values. See the snippet below of a meta tensor which doesn't have values but contains a valid `size` attribute, which is used for FLOPs calculations: ``` torch.tensor([1, 2, 3], device=torch.device("meta")) >>> tensor(..., device='meta', size=(3,), dtype=torch.int64) ``` By creating dummy values for `length_per_key` that sum to the total length of the `values` tensor we make KJT operations valid. Furthermore, model embedding ops care about the relationship between the shape of indices and the values of lengths, but since there are no values in a meta tensor, the shape stops mattering. The total number of operations remains the same, as the total length of the `values` tensor is preserved. The test plan shows FLOPs calculations remain the same whether the model is loaded on cpu or meta device. 
See design doc for more detailed context: https://docs.google.com/document/d/1DteShg9A8Nts3OTu2SrSxMp-0W2spZ9-v1br73z0LIQ/edit Reviewed By: joshuadeng Differential Revision: D54403867 fbshipit-source-id: a09ad16e269d241f66b2cba3377d878db65b581d --- torchrec/sparse/jagged_tensor.py | 12 ++++++- torchrec/sparse/tests/test_jagged_tensor.py | 39 +++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 2bff0f3d8..b23943df5 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -756,9 +756,15 @@ def _maybe_compute_length_per_key( length_per_key: Optional[List[int]], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], + values: Optional[torch.Tensor], ) -> List[int]: if length_per_key is None: - if len(keys) and offsets is not None and len(offsets) > 0: + if values is not None and values.is_meta: + # create dummy lengths per key when on meta device + total_length = values.numel() + _length = [total_length // len(keys)] * len(keys) + _length[0] += total_length % len(keys) + elif len(keys) and offsets is not None and len(offsets) > 0: _length: List[int] = ( _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key) if variable_stride_per_key @@ -789,6 +795,7 @@ def _maybe_compute_offset_per_key( offset_per_key: Optional[List[int]], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], + values: Optional[torch.Tensor], ) -> Tuple[List[int], List[int]]: if length_per_key is None: _length_per_key: List[int] = _maybe_compute_length_per_key( @@ -799,6 +806,7 @@ def _maybe_compute_offset_per_key( length_per_key=length_per_key, lengths=lengths, offsets=offsets, + values=values, ) return _length_per_key, _cumsum(_length_per_key) elif offset_per_key is None: @@ -1562,6 +1570,7 @@ def length_per_key(self) -> List[int]: length_per_key=self._length_per_key, lengths=self._lengths, offsets=self._offsets, + 
values=self._values, ) self._length_per_key = _length_per_key return _length_per_key @@ -1579,6 +1588,7 @@ def offset_per_key(self) -> List[int]: offset_per_key=self._offset_per_key, lengths=self._lengths, offsets=self._offsets, + values=self._values, ) self._length_per_key = _length_per_key self._offset_per_key = _offset_per_key diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 869c0f8d4..d34662e08 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -1795,6 +1795,45 @@ def test_equality(self) -> None: non_kjt_input = "not a KeyedJaggedTensor instance" self.assertFalse(kjt_is_equal(kt, non_kjt_input)) + def test_meta_device_compatibility(self) -> None: + keys = ["index_0", "index_1", "index_2", "index_3"] + lengths = torch.tensor( + [2, 0, 1, 1, 1, 3, 0, 2], + device=torch.device("meta"), + ) + offsets = torch.tensor( + [0, 2, 2, 3, 4, 5, 8, 8, 10], + device=torch.device("meta"), + ) + values = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + device=torch.device("meta"), + ) + weights = torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + device=torch.device("meta"), + ) + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + ) + + kjt.sync() + kjt.unsync() + + jt_dict = kjt.to_dict() + kjt = KeyedJaggedTensor.from_jt_dict(jt_dict) + + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=keys, values=values, weights=weights, lengths=lengths + ) + + kjt = KeyedJaggedTensor.from_offsets_sync( + keys=keys, values=values, weights=weights, offsets=offsets + ) + class TestKeyedJaggedTensorScripting(unittest.TestCase): def test_scriptable_forward(self) -> None: