diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 04afb8fd9..38bb0dd4b 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -132,6 +132,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
     bucketize_pos: bool = False,
     block_bucketize_pos: Optional[List[torch.Tensor]] = None,
     total_num_blocks: Optional[torch.Tensor] = None,
+    keep_original_indices: bool = False,
 ) -> Tuple[
     torch.Tensor,
     torch.Tensor,
@@ -159,6 +160,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=block_bucketize_pos,
         return_bucket_mapping=True,
+        keep_orig_idx=keep_original_indices,
     )
 
     return (
@@ -305,6 +307,7 @@ def bucketize_kjt_inference(
     bucketize_pos: bool = False,
     block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
     is_sequence: bool = False,
+    keep_original_indices: bool = False,
 ) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
     """
     Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets,
@@ -352,6 +355,7 @@ def bucketize_kjt_inference(
             total_num_blocks=total_num_buckets_new_type,
             bucketize_pos=bucketize_pos,
             block_bucketize_pos=block_bucketize_row_pos,
+            keep_original_indices=keep_original_indices,
         )
     else:
         (
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index b85e6f9c3..63dbb7a13 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -1171,6 +1171,7 @@ def _create_input_dists(
                 has_feature_processor=sharding._has_feature_processor,
                 need_pos=False,
                 embedding_shard_metadata=emb_sharding,
+                keep_original_indices=True,
             )
             self._input_dists.append(input_dist)
 
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index deac8359b..f61ea0bd8 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -649,10 +649,12 @@ def __init__(
         has_feature_processor: bool = False,
         need_pos: bool = False,
         embedding_shard_metadata: Optional[List[List[int]]] = None,
+        keep_original_indices: bool = False,
     ) -> None:
         super().__init__()
         logger.info(
             f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}"
+            f", keep_original_indices={keep_original_indices}"
         )
         self._world_size: int = world_size
         self._num_features = num_features
@@ -683,6 +685,7 @@ def __init__(
         self._embedding_shard_metadata: Optional[List[List[int]]] = (
             embedding_shard_metadata
         )
+        self._keep_original_indices = keep_original_indices
 
     def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
         block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
@@ -717,6 +720,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
                 block_bucketize_row_pos
             ),
             is_sequence=self._is_sequence,
+            keep_original_indices=self._keep_original_indices,
         )
         # KJTOneToAll
         dist_kjt = self._dist.forward(bucketized_features)