Skip to content

Commit

Permalink
Apply input transformation to eliminate recompilations (#2645)
Browse files Browse the repository at this point in the history
Summary:

KJT has

Differential Revision: D66976511
  • Loading branch information
Microve authored and facebook-github-bot committed Dec 18, 2024
1 parent 92b903f commit c139171
Showing 1 changed file with 111 additions and 3 deletions.
114 changes: 111 additions & 3 deletions torchrec/pt2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@


import functools
from typing import Any, Callable
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

"""
Prepares KJT for PT2 tracing.
Expand All @@ -28,6 +28,7 @@
def kjt_for_pt2_tracing(
kjt: KeyedJaggedTensor,
convert_to_vb: bool = False,
mark_length: bool = False,
) -> KeyedJaggedTensor:
# Breaking dependency cycle
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
Expand Down Expand Up @@ -78,25 +79,132 @@ def kjt_for_pt2_tracing(
weights = kjt.weights_or_none()
if weights is not None:
torch._dynamo.decorators.mark_unbacked(weights, 0)
if mark_length:
torch._dynamo.decorators.mark_unbacked(lengths, 0)

return KeyedJaggedTensor(
length_per_key_marked_dynamic = []

for length in kjt.length_per_key():
length_per_key_marked_dynamic.append(length)

return PT2KeyedJaggedTensor(
keys=kjt.keys(),
values=values,
lengths=lengths,
weights=weights,
stride=stride if not is_vb else None,
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
inverse_indices=inverse_indices,
length_per_key=(length_per_key_marked_dynamic if is_vb else None),
)


class PT2KeyedJaggedTensor(KeyedJaggedTensor):
    """
    KeyedJaggedTensor variant intended for PT2 tracing.

    The per-key lengths and the per-key-per-rank strides handed to the
    constructor are additionally stored as the dim-0 size of empty tensors
    whose dim 0 is marked dynamic. ``length_per_key()`` and
    ``stride_per_key_per_rank()`` read the integers back from those sizes.
    """

    def __init__(
        self,
        keys: List[str],
        values: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        stride: Optional[int] = None,
        stride_per_key_per_rank: Optional[List[List[int]]] = None,
        stride_per_key: Optional[List[int]] = None,
        length_per_key: Optional[List[int]] = None,
        lengths_offset_per_key: Optional[List[int]] = None,
        offset_per_key: Optional[List[int]] = None,
        index_per_key: Optional[Dict[str, int]] = None,
        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
    ) -> None:
        # length_per_key is intentionally NOT forwarded to the parent (None is
        # passed instead); it is kept in length_per_key_tensors below.
        super().__init__(
            keys=keys,
            values=values,
            weights=weights,
            lengths=lengths,
            offsets=offsets,
            stride=stride,
            stride_per_key_per_rank=stride_per_key_per_rank,
            stride_per_key=stride_per_key,
            length_per_key=None,
            lengths_offset_per_key=lengths_offset_per_key,
            offset_per_key=offset_per_key,
            index_per_key=index_per_key,
            jt_dict=jt_dict,
            inverse_indices=inverse_indices,
        )

        def _size_carrier(n: int) -> torch.Tensor:
            # Dynamo cannot mark a plain integer as dynamic directly, so the
            # integer is embedded as the size of dim 0 of an empty tensor and
            # that dimension is marked dynamic instead.
            carrier = torch.empty((n, 0))
            torch._dynamo.mark_dynamic(carrier, 0)
            return carrier

        # One carrier tensor per entry of length_per_key (empty when None).
        self.length_per_key_tensors: List[torch.Tensor] = [
            _size_carrier(n) for n in (length_per_key or [])
        ]

        # Same nesting as stride_per_key_per_rank, with every stride replaced
        # by its carrier tensor (empty when None).
        self.stride_per_key_per_rank_tensor: List[List[torch.Tensor]] = [
            [_size_carrier(s) for s in row]
            for row in (stride_per_key_per_rank or [])
        ]

    def length_per_key(self) -> List[int]:
        """
        Return the per-key lengths and cache them in ``_length_per_key``.

        When carrier tensors were built in ``__init__``, the lengths are read
        from their dim-0 sizes (which were marked dynamic); otherwise the
        parent implementation is used.
        """
        carriers = self.length_per_key_tensors
        if not carriers:
            self._length_per_key = super().length_per_key()
        else:
            self._length_per_key = [c.size(0) for c in carriers]
        return self._length_per_key

    def stride_per_key_per_rank(self) -> List[List[int]]:
        """
        Return the per-key-per-rank strides and cache them in
        ``_stride_per_key_per_rank``.

        When carrier tensors were built in ``__init__``, the strides are read
        from their dim-0 sizes; otherwise the parent implementation is used.
        """
        carrier_rows = self.stride_per_key_per_rank_tensor
        if not carrier_rows:
            self._stride_per_key_per_rank = super().stride_per_key_per_rank()
        else:
            self._stride_per_key_per_rank = [
                [c.size(0) for c in row] for row in carrier_rows
            ]
        return self._stride_per_key_per_rank


# pyre-ignore
def default_pipeline_input_transformer(inp):
    """
    Prepare a pipeline input object for PT2 tracing and return it.

    The object is modified in place; only attributes that are present on
    ``inp`` (and have the expected type) are touched:

    * ``id_list_features`` / ``id_score_list_features``: when the value is a
      KeyedJaggedTensor it is replaced by ``kjt_for_pt2_tracing(value)``.
    * ``uhm_history_timestamps`` / ``raw_uhm_history_timestamps`` /
      ``event_id_list_feature_invert_indexes``: when the value is a dict,
      dim 0 of every dict value is marked dynamic.
    * ``supervision_label``: dim 0 of its ``"keys"`` and ``"values"`` entries
      is marked dynamic. A ``None`` label is skipped.
    * ``event_id_list_features_seqs``: when the value is a dict, every
      KeyedJaggedTensor entry is replaced by
      ``kjt_for_pt2_tracing(entry, mark_length=True)``.

    Args:
        inp: the pipeline input object; attributes are looked up by name.

    Returns:
        The same ``inp`` object that was passed in.
    """
    # different input items need different handlings
    for attr_name in ["id_list_features", "id_score_list_features"]:
        if hasattr(inp, attr_name):
            attr = getattr(inp, attr_name)
            if isinstance(attr, KeyedJaggedTensor):
                setattr(inp, attr_name, kjt_for_pt2_tracing(attr))

    for attr_name in [
        "uhm_history_timestamps",
        "raw_uhm_history_timestamps",
        "event_id_list_feature_invert_indexes",
    ]:
        if hasattr(inp, attr_name):
            attr = getattr(inp, attr_name)
            if isinstance(attr, dict):
                for key in attr:
                    torch._dynamo.decorators.mark_dynamic(attr[key], 0)

    # The attribute may exist but be unset (None); subscripting None would
    # raise TypeError, so such a label is skipped like the other
    # absent/unsupported attributes handled above.
    supervision_label = getattr(inp, "supervision_label", None)
    if supervision_label is not None:
        torch._dynamo.decorators.mark_dynamic(supervision_label["keys"], 0)
        torch._dynamo.decorators.mark_dynamic(supervision_label["values"], 0)

    for attr_name in ["event_id_list_features_seqs"]:
        if hasattr(inp, attr_name):
            attr = getattr(inp, attr_name)
            if isinstance(attr, dict):
                for key in attr:
                    if isinstance(attr[key], KeyedJaggedTensor):
                        attr[key] = kjt_for_pt2_tracing(attr[key], mark_length=True)

            # The dict was updated in place; it is assigned back regardless.
            setattr(inp, attr_name, attr)

    return inp


Expand Down

0 comments on commit c139171

Please sign in to comment.