Commit 42c512c

seanx92 authored and facebook-github-bot committed
add option for MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT (#2544)

Summary:
Pull Request resolved: #2544

The current GPU rebatching logic for _get_batching_hinted_output does not work for pooling_factor > 1, so we need to disable it for GPU models to avoid rebatching flattened feature lengths.

Reviewed By: PaulZhang12

Differential Revision: D65499768

fbshipit-source-id: a450efa87d9e2b3e7cab4a3fcdbd94f99214e4cd
1 parent 509b0d2 commit 42c512c
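
The new module attribute defaults to True, so rebatching via _get_batching_hinted_output stays on unless a model explicitly opts out. A minimal sketch of how a GPU model could disable it, assuming a quantized embedding module instance named qec (the instance name is illustrative; the constant is the one added in this diff):

from torchrec.quant.embedding_modules import (
    MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT,
)

# Opt a GPU model out of rebatching on the non-unflattened-lengths path;
# forward() reads this with getattr(self, ..., True), so an unset
# attribute keeps the previous behavior.
setattr(qec, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT, False)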

File tree

1 file changed (+4, −1)

torchrec/quant/embedding_modules.py

Lines changed: 4 additions & 1 deletion
@@ -85,6 +85,8 @@
     "__use_unflattened_lengths_for_batching"
 )
 
+MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output"
+
 DEFAULT_ROW_ALIGNMENT = 16
 
 

@@ -913,7 +915,8 @@ def forward(
             lengths = _get_unflattened_lengths(lengths, len(embedding_names))
             lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
         else:
-            lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
+            if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT, True):
+                lookup = _get_batching_hinted_output(lengths=lengths, output=lookup)
             lengths = _get_unflattened_lengths(lengths, len(embedding_names))
         jt = construct_jagged_tensors_inference(
             embeddings=lookup,
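
Note that only the _get_batching_hinted_output call is gated: the _get_unflattened_lengths call in the else branch still runs regardless of the attribute, so lengths are always unflattened before construct_jagged_tensors_inference. The flag only skips rebatching the lookup output, which per the summary mishandles pooling_factor > 1 on GPU.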
