diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index c27c45955..803c2df8e 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -81,6 +81,11 @@ def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor: return output +@torch.fx.wrap +def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor: + return feature.lengths() + + def for_each_module_of_type_do( module: nn.Module, module_types: List[Type[torch.nn.Module]], @@ -863,9 +868,8 @@ def forward( ): f = kjts_per_key[i] indices = f.values() - lengths = f.lengths() + lengths = _get_feature_length(f) offsets = f.offsets() - lengths = f.lengths() lookup = ( emb_module(indices=indices, offsets=offsets) if self.register_tbes