diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index ec1df069b..e046c76b0 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -204,7 +204,6 @@ def bucketize_kjt_before_all2all(
     kjt: KeyedJaggedTensor,
     num_buckets: int,
     block_sizes: torch.Tensor,
-    total_num_blocks: Optional[torch.Tensor] = None,
     output_permute: bool = False,
     bucketize_pos: bool = False,
     block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
@@ -220,7 +219,6 @@ def bucketize_kjt_before_all2all(
     Args:
         num_buckets (int): number of buckets to bucketize the values into.
         block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
-        total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
         output_permute (bool): output the memory location mapping from the unbucketized
             values to bucketized values or not.
         bucketize_pos (bool): output the changed position of the bucketized values or
@@ -237,7 +235,7 @@ def bucketize_kjt_before_all2all(
         block_sizes.numel() == num_features,
         f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
     )
-
+    block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
     (
         bucketized_lengths,
         bucketized_indices,
@@ -249,24 +247,14 @@ def bucketize_kjt_before_all2all(
         kjt.values(),
         bucketize_pos=bucketize_pos,
         sequence=output_permute,
-        block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
-        total_num_blocks=(
-            _fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
-            if total_num_blocks is not None
-            else None
-        ),
+        block_sizes=block_sizes_new_type,
         my_size=num_buckets,
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=(
-            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
-            if block_bucketize_row_pos is not None
-            else None
-        ),
+        block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
         keep_orig_idx=keep_original_indices,
     )
-
     return (
         KeyedJaggedTensor(
             # duplicate keys will be resolved by AllToAll
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index e397ea29b..e513b3e35 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -389,35 +389,19 @@ def _create_input_dists(
         input_feature_names: List[str],
     ) -> None:
         for sharding, sharding_features in zip(
-            self._embedding_shardings,
-            self._sharding_features,
+            self._embedding_shardings, self._sharding_features
         ):
             assert isinstance(sharding, BaseRwEmbeddingSharding)
-            feature_num_buckets: List[int] = [
-                self._managed_collision_modules[self._feature_to_table[f]].buckets()
-                for f in sharding_features
-            ]
-
-            input_sizes: List[int] = [
+            feature_hash_sizes: List[int] = [
                 self._managed_collision_modules[self._feature_to_table[f]].input_size()
                 for f in sharding_features
             ]
 
-            feature_hash_sizes: List[int] = []
-            feature_total_num_buckets: List[int] = []
-            for input_size, num_buckets in zip(
-                input_sizes,
-                feature_num_buckets,
-            ):
-                feature_hash_sizes.append(input_size)
-                feature_total_num_buckets.append(num_buckets)
-
             input_dist = RwSparseFeaturesDist(
                 # pyre-ignore [6]
                 pg=sharding._pg,
                 num_features=sharding._get_num_features(),
                 feature_hash_sizes=feature_hash_sizes,
-                feature_total_num_buckets=feature_total_num_buckets,
                 device=sharding._device,
                 is_sequence=True,
                 has_feature_processor=sharding._has_feature_processor,
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index ce60153c9..ccba69a78 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -279,7 +279,6 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
             communication.
         num_features (int): total number of features.
         feature_hash_sizes (List[int]): hash sizes of features.
-        feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size.
        device (Optional[torch.device]): device on which buffers will be allocated.
        is_sequence (bool): if this is for a sequence embedding.
        has_feature_processor (bool): existence of feature processor (ie. position
@@ -292,7 +291,6 @@ def __init__(
         pg: dist.ProcessGroup,
         num_features: int,
         feature_hash_sizes: List[int],
-        feature_total_num_buckets: Optional[List[int]] = None,
         device: Optional[torch.device] = None,
         is_sequence: bool = False,
         has_feature_processor: bool = False,
@@ -302,16 +300,10 @@ def __init__(
         super().__init__()
         self._world_size: int = pg.size()
         self._num_features = num_features
-
-        feature_block_sizes: List[int] = []
-
-        for i, hash_size in enumerate(feature_hash_sizes):
-            block_divisor = self._world_size
-            if feature_total_num_buckets is not None:
-                assert feature_total_num_buckets[i] % self._world_size == 0
-                block_divisor = feature_total_num_buckets[i]
-            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
-
+        feature_block_sizes = [
+            (hash_size + self._world_size - 1) // self._world_size
+            for hash_size in feature_hash_sizes
+        ]
         self.register_buffer(
             "_feature_block_sizes_tensor",
             torch.tensor(
@@ -319,22 +311,7 @@ def __init__(
                 feature_block_sizes,
                 device=device,
                 dtype=torch.int64,
             ),
-            persistent=False,
         )
-        self._has_multiple_blocks_per_shard: bool = (
-            feature_total_num_buckets is not None
-        )
-        if self._has_multiple_blocks_per_shard:
-            self.register_buffer(
-                "_feature_total_num_blocks_tensor",
-                torch.tensor(
-                    [feature_total_num_buckets],
-                    device=device,
-                    dtype=torch.int64,
-                ),
-                persistent=False,
-            )
-
         self._dist = KJTAllToAll(
             pg=pg,
             splits=[self._num_features] * self._world_size,
@@ -368,11 +345,6 @@ def forward(
             sparse_features,
             num_buckets=self._world_size,
             block_sizes=self._feature_block_sizes_tensor,
-            total_num_blocks=(
-                self._feature_total_num_blocks_tensor
-                if self._has_multiple_blocks_per_shard
-                else None
-            ),
             output_permute=self._is_sequence,
             bucketize_pos=(
                 self._has_feature_processor
diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py
index 3e299192e..135d7f47a 100644
--- a/torchrec/distributed/tests/test_utils.py
+++ b/torchrec/distributed/tests/test_utils.py
@@ -336,9 +336,7 @@ def test_kjt_bucketize_before_all2all(
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()

         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt=kjt,
-            num_buckets=world_size,
-            block_sizes=block_sizes,
+            kjt, world_size, block_sizes, False, False
         )

         expected_block_bucketized_kjt = block_bucketize_ref(
@@ -435,10 +433,7 @@ def test_kjt_bucketize_before_all2all_cpu(
         """
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt=kjt,
-            num_buckets=world_size,
-            block_sizes=block_sizes,
-            block_bucketize_row_pos=block_bucketize_row_pos,
+            kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
         )

         expected_block_bucketized_kjt = block_bucketize_ref(
diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index cc0302470..2d20bc116 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -250,13 +250,6 @@ def input_size(self) -> int:
         """
         pass

-    @abc.abstractmethod
-    def buckets(self) -> int:
-        """
-        Returns number of uniform buckets, relevant to resharding
-        """
-        pass
-
     @abc.abstractmethod
     def validate_state(self) -> None:
         """
@@ -982,7 +975,6 @@ def __init__(
         name: Optional[str] = None,
         output_global_offset: int = 0,  # typically not provided by user
         output_segments: Optional[List[int]] = None,  # typically not provided by user
-        buckets: int = 1,
     ) -> None:
         if output_segments is None:
             output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1008,7 +1000,6 @@ def __init__(
         self._eviction_policy = eviction_policy

         self._current_iter: int = -1
-        self._buckets = buckets
         self._init_buffers()

         ## ------ history info ------
@@ -1311,9 +1302,6 @@ def forward(
     def output_size(self) -> int:
         return self._zch_size

-    def buckets(self) -> int:
-        return self._buckets
-
     def input_size(self) -> int:
         return self._input_hash_size

@@ -1361,5 +1349,4 @@ def rebuild_with_output_id_range(
             input_hash_func=self._input_hash_func,
             output_global_offset=output_id_range[0],
             output_segments=output_segments,
-            buckets=len(output_segments) - 1,
         )
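
Note: a minimal sketch (not part of the patch; the sizes and indices below are made up for illustration) of the single-level block-size math that RwSparseFeaturesDist is left with after this change, matching the ceil-division list comprehension in rw_sharding.py above:

    world_size = 4
    feature_hash_sizes = [100, 17]

    # Same ceil division as in RwSparseFeaturesDist.__init__ after the patch.
    feature_block_sizes = [
        (hash_size + world_size - 1) // world_size
        for hash_size in feature_hash_sizes
    ]
    assert feature_block_sizes == [25, 5]

    # Row-wise bucketization then conceptually routes an index to a rank by
    # integer division, e.g. index 61 of the first feature lands on rank 61 // 25 == 2.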