Skip to content

Commit

Permalink
KJT.empty_like FX Traceable + Scriptable (#2187)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2187

Make `KeyedJaggedTensor.empty_like` both FX-traceable and TorchScript-scriptable.

Reviewed By: iamzainhuda

Differential Revision: D59119165

fbshipit-source-id: c92e72531b45c9a86263377460577ba4008a7c09
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Jun 28, 2024
1 parent 2a0c2ed commit 2f56549
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 17 deletions.
54 changes: 37 additions & 17 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,42 @@ def _merge_weights_or_none(
return torch.cat([a_weights, b_weights], dim=0)


@torch.fx.wrap
def _strides_from_kjt(
    kjt: "KeyedJaggedTensor",
) -> Tuple[Optional[int], Optional[List[List[int]]]]:
    """Extract the stride description of ``kjt`` for reconstruction.

    Exactly one of the two returned values is non-``None``: when the KJT
    uses variable stride per key, the per-key-per-rank stride table is
    returned; otherwise the single uniform stride is returned.

    Wrapped with ``torch.fx.wrap`` so FX tracing treats this as an opaque
    leaf call rather than tracing into the KJT accessors.
    """
    if kjt.variable_stride_per_key():
        return None, kjt.stride_per_key_per_rank()
    return kjt.stride(), None


@torch.fx.wrap
def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
    """Build an empty ``KeyedJaggedTensor`` mirroring ``kjt``.

    The result has no keys and zero-length values/lengths tensors, but
    preserves ``kjt``'s device, component dtypes, weighted-ness, and
    stride configuration. Wrapped with ``torch.fx.wrap`` so FX tracing
    treats the construction as a leaf call; reading the device from
    ``kjt.device()`` also avoids hardcoding a device.
    """
    # Reuse the shared helper rather than duplicating its stride logic here.
    stride, stride_per_key_per_rank = _strides_from_kjt(kjt)

    return KeyedJaggedTensor(
        keys=[],
        values=torch.empty(0, device=kjt.device(), dtype=kjt.values().dtype),
        weights=(
            # Only materialize a weights tensor if the source KJT is weighted.
            None
            if kjt.weights_or_none() is None
            else torch.empty(0, device=kjt.device(), dtype=kjt.weights().dtype)
        ),
        lengths=torch.empty(0, device=kjt.device(), dtype=kjt.lengths().dtype),
        stride=stride,
        stride_per_key_per_rank=stride_per_key_per_rank,
    )


def _sum_by_splits(input_list: List[int], splits: List[int]) -> List[int]:
return [
sum(input_list[sum(splits[:i]) : sum(splits[:i]) + n])
Expand Down Expand Up @@ -1549,23 +1585,7 @@ def empty(

@staticmethod
def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
    """Return an empty KJT matching ``kjt``'s device, dtypes, and strides.

    Delegates to the module-level ``_kjt_empty_like`` helper, which is
    ``torch.fx.wrap``-ed so that FX tracing treats the construction as a
    single leaf call (keeping this method both traceable and scriptable).
    """
    return _kjt_empty_like(kjt)

@staticmethod
def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
Expand Down
30 changes: 30 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,36 @@ def forward(self, input: int) -> int:
self.assertEqual(ref_out, traced_out)
torch.jit.script(gm)

def test_traceable_empty_like(self) -> None:
    # Root module that builds an empty_like KJT internally and reduces it
    # to a plain int, so no KJT appears in the traced graph's output.
    class ModuleCreateAndAccessEmptyLikeKeyedJaggedTensor(torch.nn.Module):
        def forward(self, kjt: KeyedJaggedTensor) -> int:
            features = KeyedJaggedTensor.empty_like(kjt)
            total = len(features.keys())
            total += features.values().numel()
            total += features.weights().numel()
            total += features.lengths().numel()
            total += features.offsets().numel()
            return total

    module = ModuleCreateAndAccessEmptyLikeKeyedJaggedTensor()
    # NOTE(review): a weighted KJT is used on purpose so that
    # empty_like's weights branch is exercised by forward().
    sample = KeyedJaggedTensor.from_offsets_sync(
        values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
        weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]),
        keys=["index_0", "index_1"],
        offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]),
    )

    eager_result = module(sample)
    traced = symbolic_trace(module)
    self.assertEqual(eager_result, traced(sample))
    # The traced graph must also survive TorchScript scripting.
    torch.jit.script(traced)

def test_use_keyed_jagged_tensor_as_input_and_output(self) -> None:
class ModuleUseKeyedJaggedTensorAsInputAndOutput(torch.nn.Module):
def __init__(self):
Expand Down

0 comments on commit 2f56549

Please sign in to comment.