fix ZCH inference input dist error (#2682)
Summary:
Pull Request resolved: #2682

A serving test failed when ZCH sharding was enabled because the embedding lookup request went out of bounds.

Comparing against the training sharding input_dist showed that the keep_orig_idx flag is set to True when training calls the block bucketize kernel, but it is not set in inference. This diff sets the flag consistently so that ID distribution behaves the same in training and inference.
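
For context, keep_orig_idx controls whether block bucketization rebases each ID to a bucket-local offset or passes the original ID through unchanged. The pure-Python sketch below only illustrates that behavior and is not the FBGEMM kernel; the helper name bucketize_ids and the toy sizes are invented for the example.

# Illustrative sketch only, not the FBGEMM block bucketize kernel.
# With keep_orig_idx=False each ID is rebased to an offset local to its
# bucket; with keep_orig_idx=True the original ID is passed through, which
# is what the training input_dist emits and what this fix makes the
# inference input_dist emit as well.
from typing import List, Tuple


def bucketize_ids(
    ids: List[int], block_size: int, world_size: int, keep_orig_idx: bool
) -> List[Tuple[int, int]]:
    """Map each id to a (bucket, emitted_id) pair."""
    out = []
    for i in ids:
        bucket = min(i // block_size, world_size - 1)
        emitted = i if keep_orig_idx else i - bucket * block_size
        out.append((bucket, emitted))
    return out


ids = [3, 17, 25]
print(bucketize_ids(ids, block_size=10, world_size=3, keep_orig_idx=False))
# [(0, 3), (1, 7), (2, 5)]   ids rebased to bucket-local offsets
print(bucketize_ids(ids, block_size=10, world_size=3, keep_orig_idx=True))
# [(0, 3), (1, 17), (2, 25)] original ids preserved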

Reviewed By: kausv

Differential Revision: D68123290

fbshipit-source-id: 57813e55fc946dec79a522cfb341974a2ed7669d
emlin authored and facebook-github-bot committed Jan 17, 2025
1 parent 8cf154d commit dc6a789
Showing 3 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchrec/distributed/embedding_sharding.py
@@ -132,6 +132,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
bucketize_pos: bool = False,
block_bucketize_pos: Optional[List[torch.Tensor]] = None,
total_num_blocks: Optional[torch.Tensor] = None,
+keep_original_indices: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
@@ -159,6 +160,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
max_B=_fx_wrap_max_B(kjt),
block_bucketize_pos=block_bucketize_pos,
return_bucket_mapping=True,
+keep_orig_idx=keep_original_indices,
)

return (
@@ -305,6 +307,7 @@ def bucketize_kjt_inference(
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
is_sequence: bool = False,
+keep_original_indices: bool = False,
) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets,
@@ -352,6 +355,7 @@ def bucketize_kjt_inference(
total_num_blocks=total_num_buckets_new_type,
bucketize_pos=bucketize_pos,
block_bucketize_pos=block_bucketize_row_pos,
+keep_original_indices=keep_original_indices,
)
else:
(
1 change: 1 addition & 0 deletions torchrec/distributed/mc_modules.py
@@ -1171,6 +1171,7 @@ def _create_input_dists(
has_feature_processor=sharding._has_feature_processor,
need_pos=False,
embedding_shard_metadata=emb_sharding,
+keep_original_indices=True,
)
self._input_dists.append(input_dist)

4 changes: 4 additions & 0 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -649,10 +649,12 @@ def __init__(
has_feature_processor: bool = False,
need_pos: bool = False,
embedding_shard_metadata: Optional[List[List[int]]] = None,
+keep_original_indices: bool = False,
) -> None:
super().__init__()
logger.info(
f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}"
f", keep_original_indices={keep_original_indices}"
)
self._world_size: int = world_size
self._num_features = num_features
@@ -683,6 +685,7 @@ def __init__(
self._embedding_shard_metadata: Optional[List[List[int]]] = (
embedding_shard_metadata
)
+self._keep_original_indices = keep_original_indices

def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
@@ -717,6 +720,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
block_bucketize_row_pos
),
is_sequence=self._is_sequence,
+keep_original_indices=self._keep_original_indices,
)
# KJTOneToAll
dist_kjt = self._dist.forward(bucketized_features)
