add option for MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT (#2544)
Summary:
Pull Request resolved: #2544

The current GPU rebatching logic for _get_batching_hinted_output does not work for pooling_factor > 1; we need to be able to disable it for GPU models to avoid rebatching the flattened feature lengths.
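
As a usage sketch (the `quant_module` variable is hypothetical; the attribute name and its default of True come from the diff below), a GPU model can opt out like this:

    # Hypothetical sketch: `quant_module` is any torchrec quantized embedding
    # module whose forward pass hits this path. Setting the attribute to False
    # makes the new getattr guard skip _get_batching_hinted_output.
    setattr(quant_module, "__use_batching_hinted_output", False)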

Reviewed By: PaulZhang12

Differential Revision: D65499768

fbshipit-source-id: a450efa87d9e2b3e7cab4a3fcdbd94f99214e4cd
seanx92 authored and facebook-github-bot committed Nov 6, 2024
1 parent 509b0d2 commit 42c512c
Showing 1 changed file with 4 additions and 1 deletion.
torchrec/quant/embedding_modules.py (4 additions, 1 deletion):

@@ -85,6 +85,8 @@
     "__use_unflattened_lengths_for_batching"
 )
 
+MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output"
+
 DEFAULT_ROW_ALIGNMENT = 16
 
 
@@ -913,7 +915,8 @@ def forward(
                 lengths = _get_unflattened_lengths(lengths, len(embedding_names))
                 lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
             else:
-                lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
+                if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT, True):
+                    lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
                 lengths = _get_unflattened_lengths(lengths, len(embedding_names))
             jt = construct_jagged_tensors_inference(
                 embeddings=lookup,
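
Because the guard reads the attribute with a default of True, modules that never set MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT keep the old rebatching behavior. A minimal, self-contained illustration of that default (the Demo class is ours, not from the commit):

    class Demo:
        pass

    m = Demo()
    # Attribute unset: the guard's default of True keeps rebatching enabled.
    print(getattr(m, "__use_batching_hinted_output", True))  # True
    # Explicitly opt out, as a GPU model would:
    setattr(m, "__use_batching_hinted_output", False)
    print(getattr(m, "__use_batching_hinted_output", True))  # False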
