Back out "Re-shardable Hash Zch" #2552

Open · wants to merge 1 commit into base: main
18 changes: 3 additions & 15 deletions torchrec/distributed/embedding_sharding.py
@@ -204,7 +204,6 @@ def bucketize_kjt_before_all2all(
kjt: KeyedJaggedTensor,
num_buckets: int,
block_sizes: torch.Tensor,
total_num_blocks: Optional[torch.Tensor] = None,
output_permute: bool = False,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
@@ -220,7 +219,6 @@ def bucketize_kjt_before_all2all(
Args:
num_buckets (int): number of buckets to bucketize the values into.
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
output_permute (bool): output the memory location mapping from the unbucketized
values to bucketized values or not.
bucketize_pos (bool): output the changed position of the bucketized values or
@@ -237,7 +235,7 @@ def bucketize_kjt_before_all2all(
block_sizes.numel() == num_features,
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
)

block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
(
bucketized_lengths,
bucketized_indices,
@@ -249,24 +247,14 @@ def bucketize_kjt_before_all2all(
kjt.values(),
bucketize_pos=bucketize_pos,
sequence=output_permute,
block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
total_num_blocks=(
_fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
if total_num_blocks is not None
else None
),
block_sizes=block_sizes_new_type,
my_size=num_buckets,
weights=kjt.weights_or_none(),
batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
max_B=_fx_wrap_max_B(kjt),
block_bucketize_pos=(
_fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
if block_bucketize_row_pos is not None
else None
),
block_bucketize_pos=block_bucketize_row_pos, # each tensor should have the same dtype as kjt.lengths()
keep_orig_idx=keep_original_indices,
)

return (
KeyedJaggedTensor(
# duplicate keys will be resolved by AllToAll
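These hunks revert the two-level bucketization support in `bucketize_kjt_before_all2all`: the `total_num_blocks` argument and its docstring entry are removed, `block_sizes` is converted once up front to the dtype/device of `kjt.values()` (as `block_sizes_new_type`), and `block_bucketize_row_pos` is forwarded as-is, with the caller expected to match the dtype of `kjt.lengths()`. Below is a minimal sketch of calling the reverted API on CPU; it assumes fbgemm-gpu's CPU kernels are available, and the feature name, values, and block sizes are made up for illustration.

```python
import torch

from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

world_size = 2

# One feature "f1" with hash size 8, so each of the 2 buckets owns a block of
# ceil(8 / 2) = 4 ids.
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([0, 5, 3, 7], dtype=torch.int64),
    lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64),
)
block_sizes = torch.tensor([4], dtype=torch.int64)

# After the backout there is no total_num_blocks argument: one block per bucket.
# Values 0 and 3 land in bucket 0; 5 and 7 land in bucket 1 (their indices may
# be rebased within the block depending on keep_original_indices).
bucketized_kjt, _ = bucketize_kjt_before_all2all(
    kjt=kjt,
    num_buckets=world_size,
    block_sizes=block_sizes,
    output_permute=False,
    bucketize_pos=False,
)
print(bucketized_kjt.lengths())  # per-bucket, per-sample lengths
```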
20 changes: 2 additions & 18 deletions torchrec/distributed/mc_modules.py
@@ -389,35 +389,19 @@ def _create_input_dists(
input_feature_names: List[str],
) -> None:
for sharding, sharding_features in zip(
self._embedding_shardings,
self._sharding_features,
self._embedding_shardings, self._sharding_features
):
assert isinstance(sharding, BaseRwEmbeddingSharding)
feature_num_buckets: List[int] = [
self._managed_collision_modules[self._feature_to_table[f]].buckets()
for f in sharding_features
]

input_sizes: List[int] = [
feature_hash_sizes: List[int] = [
self._managed_collision_modules[self._feature_to_table[f]].input_size()
for f in sharding_features
]

feature_hash_sizes: List[int] = []
feature_total_num_buckets: List[int] = []
for input_size, num_buckets in zip(
input_sizes,
feature_num_buckets,
):
feature_hash_sizes.append(input_size)
feature_total_num_buckets.append(num_buckets)

input_dist = RwSparseFeaturesDist(
# pyre-ignore [6]
pg=sharding._pg,
num_features=sharding._get_num_features(),
feature_hash_sizes=feature_hash_sizes,
feature_total_num_buckets=feature_total_num_buckets,
device=sharding._device,
is_sequence=True,
has_feature_processor=sharding._has_feature_processor,
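In `_create_input_dists`, the backout drops everything bucket-related: the per-feature `buckets()` lookups, the parallel `feature_total_num_buckets` list, and the `feature_total_num_buckets=` argument to `RwSparseFeaturesDist`; only the per-feature `input_size()` values survive as `feature_hash_sizes`. The toy sketch below mirrors that simplification, with plain dictionaries standing in for `self._managed_collision_modules` and `self._feature_to_table` (all names and sizes are illustrative).

```python
from typing import Dict, List

feature_to_table: Dict[str, str] = {"f1": "t1", "f2": "t1", "f3": "t2"}
table_input_size: Dict[str, int] = {"t1": 4_000_000, "t2": 1_000_000}

sharding_features: List[str] = ["f1", "f2", "f3"]

# A single comprehension replaces the removed two-list loop over
# (input_sizes, feature_num_buckets); buckets() is never consulted.
feature_hash_sizes: List[int] = [
    table_input_size[feature_to_table[f]] for f in sharding_features
]
assert feature_hash_sizes == [4_000_000, 4_000_000, 1_000_000]
# feature_hash_sizes is the only size information passed on to
# RwSparseFeaturesDist after the backout.
```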
36 changes: 4 additions & 32 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -279,7 +279,6 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
communication.
num_features (int): total number of features.
feature_hash_sizes (List[int]): hash sizes of features.
feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size.
device (Optional[torch.device]): device on which buffers will be allocated.
is_sequence (bool): if this is for a sequence embedding.
has_feature_processor (bool): existence of feature processor (ie. position
@@ -292,7 +291,6 @@ def __init__(
pg: dist.ProcessGroup,
num_features: int,
feature_hash_sizes: List[int],
feature_total_num_buckets: Optional[List[int]] = None,
device: Optional[torch.device] = None,
is_sequence: bool = False,
has_feature_processor: bool = False,
@@ -302,39 +300,18 @@
super().__init__()
self._world_size: int = pg.size()
self._num_features = num_features

feature_block_sizes: List[int] = []

for i, hash_size in enumerate(feature_hash_sizes):
block_divisor = self._world_size
if feature_total_num_buckets is not None:
assert feature_total_num_buckets[i] % self._world_size == 0
block_divisor = feature_total_num_buckets[i]
feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)

feature_block_sizes = [
(hash_size + self._world_size - 1) // self._world_size
for hash_size in feature_hash_sizes
]
self.register_buffer(
"_feature_block_sizes_tensor",
torch.tensor(
feature_block_sizes,
device=device,
dtype=torch.int64,
),
persistent=False,
)
self._has_multiple_blocks_per_shard: bool = (
feature_total_num_buckets is not None
)
if self._has_multiple_blocks_per_shard:
self.register_buffer(
"_feature_total_num_blocks_tensor",
torch.tensor(
[feature_total_num_buckets],
device=device,
dtype=torch.int64,
),
persistent=False,
)

self._dist = KJTAllToAll(
pg=pg,
splits=[self._num_features] * self._world_size,
@@ -368,11 +345,6 @@ def forward(
sparse_features,
num_buckets=self._world_size,
block_sizes=self._feature_block_sizes_tensor,
total_num_blocks=(
self._feature_total_num_blocks_tensor
if self._has_multiple_blocks_per_shard
else None
),
output_permute=self._is_sequence,
bucketize_pos=(
self._has_feature_processor
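With `feature_total_num_buckets` gone from `RwSparseFeaturesDist`, block sizes fall back to exactly one block per rank: the ceiling of each feature's hash size over the world size, registered as `_feature_block_sizes_tensor`, and `forward` no longer passes `total_num_blocks` to `bucketize_kjt_before_all2all`. A small worked sketch of the restored rule (numbers are illustrative):

```python
from typing import List

world_size = 4
feature_hash_sizes: List[int] = [10, 4_000_000]

# ceil(hash_size / world_size) via integer arithmetic, as in the restored code.
feature_block_sizes: List[int] = [
    (hash_size + world_size - 1) // world_size
    for hash_size in feature_hash_sizes
]
assert feature_block_sizes == [3, 1_000_000]

# Routing a row to its owning rank is then a single integer division; the
# reverted total_num_blocks path instead allowed several uniform blocks per
# rank (with the total bucket count a multiple of world size).
row_id = 2_500_000
owning_rank = row_id // feature_block_sizes[1]
assert owning_rank == 2
```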
9 changes: 2 additions & 7 deletions torchrec/distributed/tests/test_utils.py
@@ -336,9 +336,7 @@ def test_kjt_bucketize_before_all2all(
block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()

block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
kjt=kjt,
num_buckets=world_size,
block_sizes=block_sizes,
kjt, world_size, block_sizes, False, False
)

expected_block_bucketized_kjt = block_bucketize_ref(
@@ -435,10 +433,7 @@ def test_kjt_bucketize_before_all2all_cpu(
"""
block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
kjt=kjt,
num_buckets=world_size,
block_sizes=block_sizes,
block_bucketize_row_pos=block_bucketize_row_pos,
kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
)

expected_block_bucketized_kjt = block_bucketize_ref(
13 changes: 0 additions & 13 deletions torchrec/modules/mc_modules.py
@@ -250,13 +250,6 @@ def input_size(self) -> int:
"""
pass

@abc.abstractmethod
def buckets(self) -> int:
"""
Returns number of uniform buckets, relevant to resharding
"""
pass

@abc.abstractmethod
def validate_state(self) -> None:
"""
@@ -982,7 +975,6 @@ def __init__(
name: Optional[str] = None,
output_global_offset: int = 0, # typically not provided by user
output_segments: Optional[List[int]] = None, # typically not provided by user
buckets: int = 1,
) -> None:
if output_segments is None:
output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1008,7 +1000,6 @@ def __init__(
self._eviction_policy = eviction_policy

self._current_iter: int = -1
self._buckets = buckets
self._init_buffers()

## ------ history info ------
@@ -1311,9 +1302,6 @@ def forward(
def output_size(self) -> int:
return self._zch_size

def buckets(self) -> int:
return self._buckets

def input_size(self) -> int:
return self._input_hash_size

@@ -1361,5 +1349,4 @@ def rebuild_with_output_id_range(
input_hash_func=self._input_hash_func,
output_global_offset=output_id_range[0],
output_segments=output_segments,
buckets=len(output_segments) - 1,
)
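On the module side, the backout removes the `buckets()` abstract method from the managed-collision interface, along with the `buckets` constructor argument, the `_buckets` attribute, and the `buckets()` accessor of `MCHManagedCollisionModule`; `rebuild_with_output_id_range` no longer derives a bucket count from `len(output_segments) - 1`. The toy interface below sketches what remains of the size-related surface (illustrative only; the real `ManagedCollisionModule` defines more abstract methods than shown here).

```python
import abc


class McSizeInterfaceSketch(abc.ABC):
    """Toy stand-in for the trimmed size-related interface."""

    @abc.abstractmethod
    def input_size(self) -> int:
        """Size of the input (pre-remap) id space."""
        ...

    @abc.abstractmethod
    def output_size(self) -> int:
        """Size of the remapped (ZCH) output id space."""
        ...

    # buckets() is gone: callers such as the sharded input-dist setup can no
    # longer ask how many uniform buckets a module spans.


class ToyMcModule(McSizeInterfaceSketch):
    def __init__(self, input_hash_size: int, zch_size: int) -> None:
        self._input_hash_size = input_hash_size
        self._zch_size = zch_size

    def input_size(self) -> int:
        return self._input_hash_size

    def output_size(self) -> int:
        return self._zch_size


mc = ToyMcModule(input_hash_size=2**63 - 1, zch_size=10_000)
assert (mc.input_size(), mc.output_size()) == (2**63 - 1, 10_000)
```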