support stbe length rebatching and remove stbe output padding for MTIA
Summary:
1. When rebatching STBE lengths without the output, the lengths must be a 2D tensor of shape [F x B] so that we can concat directly at dim 1; we use _get_unflattened_lengths as the batch-info rule hint.
2. For MTIA inference, if the STBE runs remotely, its output is padded to the max batch size, which breaks the subsequent split. In that case we remove the padding and restore the output to its original size.
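
A minimal sketch of the shape contract in (1), assuming two ranks each with F=2 features and B=2; the tensors are illustrative, not taken from the code below:

    import torch

    # Per-rank STBE lengths, already unflattened to [F, B].
    lengths_rank0 = torch.tensor([[1, 2], [0, 3]])
    lengths_rank1 = torch.tensor([[2, 0], [1, 1]])
    # Because both are [F, B], rebatching is a direct concat at dim 1 -> [F, B0 + B1].
    rebatched = torch.cat([lengths_rank0, lengths_rank1], dim=1)  # shape [2, 4]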

Reviewed By: PaulZhang12

Differential Revision: D64914077

fbshipit-source-id: e386be5721dd79aefcf2ea8f12c8399b900aa395
seanx92 authored and facebook-github-bot committed Oct 25, 2024
1 parent cd64b9d commit f606d5c
Showing 2 changed files with 37 additions and 6 deletions.
33 changes: 28 additions & 5 deletions torchrec/modules/utils.py
@@ -40,6 +40,22 @@ def _fx_to_list(tensor: torch.Tensor) -> List[int]:
     return tensor.long().tolist()
 
 
+@torch.fx.wrap
+def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
+    """
+    Unflatten lengths tensor from [F * B] to [F, B].
+    """
+    return lengths.view(num_features, -1)
+
+
+@torch.fx.wrap
+def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor:
+    """
+    Slice tensor.
+    """
+    return tensor[start:end]
+
+
 def extract_module_or_tensor_callable(
     module_or_callable: Union[
         Callable[[], torch.nn.Module],
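
For illustration, how the two new helpers compose on toy inputs (F=2, B=3, and the padded tensor below are assumptions, not values from the diff):

    import torch

    lengths = torch.tensor([1, 0, 2, 3, 1, 0])          # flattened [F * B], F=2, B=3
    unflattened = _get_unflattened_lengths(lengths, 2)  # [[1, 0, 2], [3, 1, 0]], shape [F, B]
    padded = torch.randn(10, 4)                         # output padded to 10 rows
    # Keep only the first lengths.sum() rows; the slice runs along dim 0.
    trimmed = _slice_1d_tensor(padded, 0, int(unflattened.sum().item()))  # 7 rows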
@@ -292,20 +308,27 @@ def construct_jagged_tensors_inference(
     need_indices: bool = False,
     features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
     reverse_indices: Optional[torch.Tensor] = None,
+    remove_padding: bool = False,
 ) -> Dict[str, JaggedTensor]:
     with record_function("## construct_jagged_tensors_inference ##"):
+        # [F * B] -> [F, B]
+        unflattened_lengths = _get_unflattened_lengths(lengths, len(embedding_names))
+
         if reverse_indices is not None:
             embeddings = torch.index_select(
                 embeddings, 0, reverse_indices.to(torch.int32)
             )
+        elif remove_padding:
+            embeddings = _slice_1d_tensor(
+                embeddings, 0, unflattened_lengths.sum().item()
+            )
 
         ret: Dict[str, JaggedTensor] = {}
-        length_per_key: List[int] = _fx_to_list(
-            torch.sum(lengths.view(len(embedding_names), -1), dim=1)
-        )
-        lengths = lengths.view(len(embedding_names), -1)
-        lengths_tuple = torch.unbind(lengths, dim=0)
+
+        length_per_key: List[int] = _fx_to_list(torch.sum(unflattened_lengths, dim=1))
+
+        lengths_tuple = torch.unbind(unflattened_lengths, dim=0)
 
         embeddings_list = torch.split(embeddings, length_per_key, dim=0)
         values_list = torch.split(values, length_per_key) if need_indices else None
 
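To see why (2) matters, a small end-to-end sketch of this path (toy numbers; F=2, D=4 are assumptions, and the trimming mirrors the elif remove_padding branch above):

    import torch

    lengths = torch.tensor([2, 1, 0, 1])                     # [F * B] with F=2, B=2
    unflattened = lengths.view(2, -1)                        # [F, B]
    padded = torch.randn(8, 4)                               # remote STBE output padded to 8 rows
    # torch.split(padded, [3, 1]) would fail because 3 + 1 != 8,
    # so trim back to the true row count first.
    embeddings = padded[: int(unflattened.sum())]            # 4 rows
    length_per_key = torch.sum(unflattened, dim=1).tolist()  # [3, 1]
    per_key = torch.split(embeddings, length_per_key, dim=0) # now the split works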
10 changes: 9 additions & 1 deletion torchrec/quant/embedding_modules.py
@@ -79,6 +79,10 @@
"__emb_name_to_num_rows_post_pruning"
)

MODULE_ATTR_REMOVE_STBE_PADDING_BOOL: str = "__remove_stbe_padding"

MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT_BOOL: str = "__use_batching_hinted_output"

DEFAULT_ROW_ALIGNMENT = 16


@@ -894,14 +898,18 @@ def forward(
                 if self.register_tbes
                 else emb_module.forward(indices=indices, offsets=offsets)
             )
-            lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
+            if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT_BOOL, True):
+                lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
             embedding_names = self._embedding_names_by_batched_tables[key]
             jt = construct_jagged_tensors_inference(
                 embeddings=lookup,
                 lengths=lengths,
                 values=indices,
                 embedding_names=embedding_names,
                 need_indices=self.need_indices(),
+                remove_padding=getattr(
+                    self, MODULE_ATTR_REMOVE_STBE_PADDING_BOOL, False
+                ),
             )
             for embedding_name in embedding_names:
                 feature_embeddings[embedding_name] = jt[embedding_name]
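
One way to exercise the new flags (a sketch, not the production setup; the stand-in module below is hypothetical, while the attribute names are the real constants from this diff):

    import torch

    module = torch.nn.Module()  # stand-in for a quantized embedding module with this forward
    setattr(module, MODULE_ATTR_REMOVE_STBE_PADDING_BOOL, True)
    setattr(module, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT_BOOL, False)
    # forward() then skips _get_batching_hinted_output and trims the padded
    # remote STBE output back to lengths.sum() rows before splitting.
    assert getattr(module, MODULE_ATTR_REMOVE_STBE_PADDING_BOOL, False)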
