From dc6a78944a64601d1caa8238ff3f00af8e077251 Mon Sep 17 00:00:00 2001 From: Emma Lin Date: Thu, 16 Jan 2025 16:03:17 -0800 Subject: [PATCH] fix ZCH inference input dist error (#2682) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2682 Serving test failed when ZCH sharding is enabled because embedding lookup request Out of Bound. After compared with training sharding input_dist, found that the keep_orig_idx flag is True when calling training block bucketize kernel, but not set in inference. So this diff is to make that consistent between training and inference, so the ID distribution works the same between training and inference. Reviewed By: kausv Differential Revision: D68123290 fbshipit-source-id: 57813e55fc946dec79a522cfb341974a2ed7669d --- torchrec/distributed/embedding_sharding.py | 4 ++++ torchrec/distributed/mc_modules.py | 1 + torchrec/distributed/sharding/rw_sharding.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 04afb8fd9..38bb0dd4b 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -132,6 +132,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( bucketize_pos: bool = False, block_bucketize_pos: Optional[List[torch.Tensor]] = None, total_num_blocks: Optional[torch.Tensor] = None, + keep_original_indices: bool = False, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -159,6 +160,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( max_B=_fx_wrap_max_B(kjt), block_bucketize_pos=block_bucketize_pos, return_bucket_mapping=True, + keep_orig_idx=keep_original_indices, ) return ( @@ -305,6 +307,7 @@ def bucketize_kjt_inference( bucketize_pos: bool = False, block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, is_sequence: bool = False, + keep_original_indices: bool = False, ) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets, @@ -352,6 +355,7 @@ def bucketize_kjt_inference( total_num_blocks=total_num_buckets_new_type, bucketize_pos=bucketize_pos, block_bucketize_pos=block_bucketize_row_pos, + keep_original_indices=keep_original_indices, ) else: ( diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index b85e6f9c3..63dbb7a13 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -1171,6 +1171,7 @@ def _create_input_dists( has_feature_processor=sharding._has_feature_processor, need_pos=False, embedding_shard_metadata=emb_sharding, + keep_original_indices=True, ) self._input_dists.append(input_dist) diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index deac8359b..f61ea0bd8 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -649,10 +649,12 @@ def __init__( has_feature_processor: bool = False, need_pos: bool = False, embedding_shard_metadata: Optional[List[List[int]]] = None, + keep_original_indices: bool = False, ) -> None: super().__init__() logger.info( f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}" + f", keep_original_indices={keep_original_indices}" ) self._world_size: int = world_size self._num_features = num_features @@ -683,6 +685,7 @@ def __init__( self._embedding_shard_metadata: Optional[List[List[int]]] = ( embedding_shard_metadata ) + self._keep_original_indices = keep_original_indices def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device( @@ -717,6 +720,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: block_bucketize_row_pos ), is_sequence=self._is_sequence, + keep_original_indices=self._keep_original_indices, ) # KJTOneToAll dist_kjt = self._dist.forward(bucketized_features)