simple fx rule for get length tensor (pytorch#1767)
Summary:
Pull Request resolved: pytorch#1767

ATT

Reviewed By: jingsh, jiayisuse

Differential Revision: D54603545

fbshipit-source-id: 534466ada1ba9161de3e8e2459a7aae417d59daf
YazhiGao authored and facebook-github-bot committed Mar 8, 2024
1 parent 1d6ce32 commit 1df82cc
Showing 1 changed file with 6 additions and 2 deletions: torchrec/quant/embedding_modules.py
@@ -81,6 +81,11 @@ def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
     return output


+@torch.fx.wrap
+def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
+    return feature.lengths()
+
+
 def for_each_module_of_type_do(
     module: nn.Module,
     module_types: List[Type[torch.nn.Module]],
@@ -863,9 +868,8 @@ def forward(
         ):
             f = kjts_per_key[i]
             indices = f.values()
-            lengths = f.lengths()
+            lengths = _get_feature_length(f)
             offsets = f.offsets()
-            lengths = f.lengths()
             lookup = (
                 emb_module(indices=indices, offsets=offsets)
                 if self.register_tbes
