Skip to content

Commit

Permalink
Faster KJT init (pytorch#2231)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2231

To improve inference, we want to make creating a KJT as cheap as possible, which means the init method is nothing more than an attribute setter.  All other fields are calculated lazily.  This is particularly important with respect to jit script and moving between compilation units.

Reviewed By: joshuadeng

Differential Revision: D59765149

fbshipit-source-id: af097d3d5b79cc2f264e7a75efb6774b44f574c3
  • Loading branch information
dstaay-fb authored and facebook-github-bot committed Sep 3, 2024
1 parent 2a5df95 commit c43184e
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 49 deletions.
180 changes: 134 additions & 46 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ def _get_weights_or_throw(weights: Optional[torch.Tensor]) -> torch.Tensor:
return weights


def _get_lengths_offset_per_key_or_throw(
lengths_offset_per_key: Optional[List[int]],
) -> List[int]:
assert (
lengths_offset_per_key is not None
), "This (Keyed)JaggedTensor doesn't have lengths_offset_per_key."
return lengths_offset_per_key


def _get_stride_per_key_or_throw(stride_per_key: Optional[List[int]]) -> List[int]:
assert (
stride_per_key is not None
), "This (Keyed)JaggedTensor doesn't have stride_per_key."
return stride_per_key


def _get_inverse_indices_or_throw(
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
) -> Tuple[List[str], torch.Tensor]:
Expand Down Expand Up @@ -885,9 +901,9 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten


def _assert_tensor_has_no_elements_or_has_integers(
tensor: torch.Tensor, tensor_name: str
tensor: Optional[torch.Tensor], tensor_name: str
) -> None:
if is_torchdynamo_compiling():
if is_torchdynamo_compiling() or tensor is None:
# Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes.
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
return
Expand Down Expand Up @@ -915,10 +931,13 @@ def _maybe_compute_stride_kjt(
stride: Optional[int],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
if stride is None:
if len(keys) == 0:
stride = 0
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
stride = max([sum(s) for s in stride_per_key_per_rank])
elif offsets is not None and offsets.numel() > 0:
stride = (offsets.numel() - 1) // len(keys)
elif lengths is not None:
Expand Down Expand Up @@ -1461,6 +1480,38 @@ def _check_attributes(
return True


def _maybe_compute_lengths_offset_per_key(
lengths_offset_per_key: Optional[List[int]],
stride_per_key: Optional[List[int]],
stride: Optional[int],
keys: List[str],
) -> Optional[List[int]]:
if lengths_offset_per_key is not None:
return lengths_offset_per_key
elif stride_per_key is not None:
return _cumsum(stride_per_key)
elif stride is not None:
return _cumsum([stride] * len(keys))
else:
return None


def _maybe_compute_stride_per_key(
stride_per_key: Optional[List[int]],
stride_per_key_per_rank: Optional[List[List[int]]],
stride: Optional[int],
keys: List[str],
) -> Optional[List[int]]:
if stride_per_key is not None:
return stride_per_key
elif stride_per_key_per_rank is not None:
return [sum(s) for s in stride_per_key_per_rank]
elif stride is not None:
return [stride] * len(keys)
else:
return None


class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
"""Represents an (optionally weighted) keyed jagged tensor.
Expand Down Expand Up @@ -1534,62 +1585,53 @@ def __init__(
stride: Optional[int] = None,
stride_per_key_per_rank: Optional[List[List[int]]] = None,
# Below exposed to ensure torch.script-able
stride_per_key: Optional[List[int]] = None,
length_per_key: Optional[List[int]] = None,
lengths_offset_per_key: Optional[List[int]] = None,
offset_per_key: Optional[List[int]] = None,
index_per_key: Optional[Dict[str, int]] = None,
jt_dict: Optional[Dict[str, JaggedTensor]] = None,
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
) -> None:
"""
This constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible.
It is important to only assign attributes or do input checks here, to support various
internal inference optimizations. By convention, each attribute is named the same as its input arg, just
with a leading underscore.
"""
self._keys: List[str] = keys
self._values: torch.Tensor = values
self._weights: Optional[torch.Tensor] = weights
if offsets is not None:
_assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
if lengths is not None:
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
self._lengths: Optional[torch.Tensor] = lengths
self._offsets: Optional[torch.Tensor] = offsets

self._stride_per_key_per_rank: List[List[int]] = []
self._stride_per_key: List[int] = []
self._variable_stride_per_key: bool = False
self._stride: int = -1

if stride_per_key_per_rank is not None:
self._stride_per_key_per_rank = stride_per_key_per_rank
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
self._variable_stride_per_key = True
if stride is not None:
self._stride = stride
else:
self._stride = (
max(self._stride_per_key) if len(self._stride_per_key) > 0 else 0
)
else:
stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets)
self._stride = stride
self._stride_per_key_per_rank = [[stride]] * len(self._keys)
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]

# lazy fields
self._stride: Optional[int] = stride
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
stride_per_key_per_rank
)
self._stride_per_key: Optional[List[int]] = stride_per_key
self._length_per_key: Optional[List[int]] = length_per_key
self._offset_per_key: Optional[List[int]] = offset_per_key
self._lengths_offset_per_key: Optional[List[int]] = lengths_offset_per_key
self._index_per_key: Optional[Dict[str, int]] = index_per_key
self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict
self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
inverse_indices
)
self._lengths_offset_per_key: List[int] = []

self._init_pt2_checks()
# validation logic
if not torch.jit.is_scripting():
_assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
self._init_pt2_checks()

def _init_pt2_checks(self) -> None:
if torch.jit.is_scripting() or not is_torchdynamo_compiling():
return

pt2_checks_all_is_size(self._stride_per_key)
for s in self._stride_per_key_per_rank:
pt2_checks_all_is_size(s)
if self._stride_per_key is not None:
pt2_checks_all_is_size(self._stride_per_key)
if self._stride_per_key_per_rank is not None:
# pyre-ignore [16]
for s in self._stride_per_key_per_rank:
pt2_checks_all_is_size(s)

@staticmethod
def from_offsets_sync(
Expand Down Expand Up @@ -1839,16 +1881,32 @@ def weights_or_none(self) -> Optional[torch.Tensor]:
return self._weights

def stride(self) -> int:
return self._stride
stride = _maybe_compute_stride_kjt(
self._keys,
self._stride,
self._lengths,
self._offsets,
self._stride_per_key_per_rank,
)
self._stride = stride
return stride

def stride_per_key(self) -> List[int]:
return self._stride_per_key
stride_per_key = _maybe_compute_stride_per_key(
self._stride_per_key,
self._stride_per_key_per_rank,
self.stride(),
self._keys,
)
self._stride_per_key = stride_per_key
return _get_stride_per_key_or_throw(stride_per_key)

def stride_per_key_per_rank(self) -> List[List[int]]:
return self._stride_per_key_per_rank
stride_per_key_per_rank = self._stride_per_key_per_rank
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []

def variable_stride_per_key(self) -> bool:
return self._variable_stride_per_key
return self._stride_per_key_per_rank is not None

def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
return _get_inverse_indices_or_throw(self._inverse_indices)
Expand Down Expand Up @@ -1901,9 +1959,20 @@ def offset_per_key_or_none(self) -> Optional[List[int]]:
return self._offset_per_key

def lengths_offset_per_key(self) -> List[int]:
if not self._lengths_offset_per_key:
self._lengths_offset_per_key = _cumsum(self.stride_per_key())
return self._lengths_offset_per_key
if self.variable_stride_per_key():
_lengths_offset_per_key = _maybe_compute_lengths_offset_per_key(
self._lengths_offset_per_key,
self.stride_per_key(),
None,
self._keys,
)
else:
_lengths_offset_per_key = _maybe_compute_lengths_offset_per_key(
self._lengths_offset_per_key, None, self.stride(), self._keys
)

self._lengths_offset_per_key = _lengths_offset_per_key
return _get_lengths_offset_per_key_or_throw(_lengths_offset_per_key)

def index_per_key(self) -> Dict[str, int]:
return self._key_indices()
Expand Down Expand Up @@ -1934,7 +2003,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
offsets=self._offsets,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=None,
length_per_key=self._length_per_key,
lengths_offset_per_key=None,
offset_per_key=self._offset_per_key,
index_per_key=self._index_per_key,
jt_dict=self._jt_dict,
Expand Down Expand Up @@ -1968,7 +2039,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
),
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=None,
length_per_key=None,
lengths_offset_per_key=None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
Expand Down Expand Up @@ -2012,7 +2085,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
offsets=None,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=None,
length_per_key=split_length_per_key,
lengths_offset_per_key=None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
Expand Down Expand Up @@ -2046,7 +2121,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
offsets=None,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=None,
length_per_key=split_length_per_key,
lengths_offset_per_key=None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
Expand Down Expand Up @@ -2074,10 +2151,11 @@ def permute(
for index in indices:
key = self.keys()[index]
permuted_keys.append(key)
permuted_stride_per_key_per_rank.append(
self.stride_per_key_per_rank()[index]
)
permuted_length_per_key.append(length_per_key[index])
if self.variable_stride_per_key():
permuted_stride_per_key_per_rank.append(
self.stride_per_key_per_rank()[index]
)

permuted_length_per_key_sum = sum(permuted_length_per_key)
if not torch.jit.is_scripting() and is_non_strict_exporting():
Expand Down Expand Up @@ -2140,7 +2218,9 @@ def permute(
offsets=None,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=None,
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
lengths_offset_per_key=None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
Expand All @@ -2160,7 +2240,9 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
offsets=None,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=None,
length_per_key=self.length_per_key(),
lengths_offset_per_key=None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
Expand Down Expand Up @@ -2280,8 +2362,10 @@ def to(
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
)
length_per_key = self._length_per_key
lengths_offset_per_key = self._lengths_offset_per_key
offset_per_key = self._offset_per_key
index_per_key = self._index_per_key
stride_per_key = self._stride_per_key
jt_dict = self._jt_dict
inverse_indices = self._inverse_indices
if inverse_indices is not None:
Expand Down Expand Up @@ -2313,7 +2397,9 @@ def to(
),
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=stride_per_key,
length_per_key=length_per_key,
lengths_offset_per_key=lengths_offset_per_key,
offset_per_key=offset_per_key,
index_per_key=index_per_key,
jt_dict=jt_dict,
Expand Down Expand Up @@ -2363,7 +2449,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
offsets=offsets.pin_memory() if offsets is not None else None,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key=self._stride_per_key,
length_per_key=self._length_per_key,
lengths_offset_per_key=self._lengths_offset_per_key,
offset_per_key=self._offset_per_key,
index_per_key=self._index_per_key,
jt_dict=None,
Expand Down
6 changes: 3 additions & 3 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,13 +2121,13 @@ def forward(
lengths=input.lengths(),
offsets=input.offsets(),
)
return output, output._stride
return output, output.stride()

# Case 3: KeyedJaggedTensor is used as both an input and an output of the root module.
m = ModuleUseKeyedJaggedTensorAsInputAndOutput()
gm = symbolic_trace(m)
FileCheck().check("KeyedJaggedTensor").check("keys()").check("values()").check(
"._stride"
"stride"
).run(gm.code)
input = KeyedJaggedTensor.from_offsets_sync(
values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
Expand Down Expand Up @@ -2185,7 +2185,7 @@ def forward(
lengths: torch.Tensor,
) -> Tuple[KeyedJaggedTensor, int]:
output = KeyedJaggedTensor(keys, values, weights, lengths)
return output, output._stride
return output, output.stride()

# Case 1: KeyedJaggedTensor is only used as an output of the root module.
m = ModuleUseKeyedJaggedTensorAsOutput()
Expand Down

0 comments on commit c43184e

Please sign in to comment.