diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index f703fe8ec..b5fe73002 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1144,6 +1144,42 @@ def _merge_weights_or_none( return torch.cat([a_weights, b_weights], dim=0) +@torch.fx.wrap +def _strides_from_kjt( + kjt: "KeyedJaggedTensor", +) -> Tuple[Optional[int], Optional[List[List[int]]]]: + stride, stride_per_key_per_rank = ( + (None, kjt.stride_per_key_per_rank()) + if kjt.variable_stride_per_key() + else (kjt.stride(), None) + ) + + return stride, stride_per_key_per_rank + + +@torch.fx.wrap +def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": + # empty like function fx wrapped, also avoids device hardcoding + stride, stride_per_key_per_rank = ( + (None, kjt.stride_per_key_per_rank()) + if kjt.variable_stride_per_key() + else (kjt.stride(), None) + ) + + return KeyedJaggedTensor( + keys=[], + values=torch.empty(0, device=kjt.device(), dtype=kjt.values().dtype), + weights=( + None + if kjt.weights_or_none() is None + else torch.empty(0, device=kjt.device(), dtype=kjt.weights().dtype) + ), + lengths=torch.empty(0, device=kjt.device(), dtype=kjt.lengths().dtype), + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + + def _sum_by_splits(input_list: List[int], splits: List[int]) -> List[int]: return [ sum(input_list[sum(splits[:i]) : sum(splits[:i]) + n]) @@ -1549,23 +1585,7 @@ def empty( @staticmethod def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": - stride, stride_per_key_per_rank = ( - (None, kjt.stride_per_key_per_rank()) - if kjt.variable_stride_per_key() - else (kjt.stride(), None) - ) - return KeyedJaggedTensor( - keys=[], - values=torch.empty(0, device=kjt.device(), dtype=kjt.values().dtype), - weights=( - None - if kjt.weights_or_none() is None - else torch.empty(0, device=kjt.device(), dtype=kjt.weights().dtype) - ), - lengths=torch.empty(0, device=kjt.device(), dtype=kjt.lengths().dtype), - stride=stride, - stride_per_key_per_rank=stride_per_key_per_rank, - ) + return _kjt_empty_like(kjt) @staticmethod def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor": diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 9efeb444c..566fcfef7 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -1999,6 +1999,36 @@ def forward(self, input: int) -> int: self.assertEqual(ref_out, traced_out) torch.jit.script(gm) + def test_traceable_empty_like(self) -> None: + class ModuleCreateAndAccessEmptyLikeKeyedJaggedTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, kjt: KeyedJaggedTensor) -> int: + features = KeyedJaggedTensor.empty_like(kjt) + return ( + len(features.keys()) + + features.values().numel() + + features.weights().numel() + + features.lengths().numel() + + features.offsets().numel() + ) + + # Case 4: KeyedJaggedTensor is only used within the root module and not as part of + # the root module's input/output interface. + m = ModuleCreateAndAccessEmptyLikeKeyedJaggedTensor() + kjt = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + gm = symbolic_trace(m) + ref_out = m(kjt) + traced_out = gm(kjt) + self.assertEqual(ref_out, traced_out) + torch.jit.script(gm) + def test_use_keyed_jagged_tensor_as_input_and_output(self) -> None: class ModuleUseKeyedJaggedTensorAsInputAndOutput(torch.nn.Module): def __init__(self):