Skip to content

Commit

Permalink
Revert D61388755: Add ShardedQuantManagedCollisionEmbeddingCollection
Browse files Browse the repository at this point in the history
Differential Revision:
D61388755

Original commit changeset: d222a9db8842

Original Phabricator Diff: D61388755

fbshipit-source-id: cc8da12183124d7271e0d0f9773730bd7eea3938
  • Loading branch information
Hongtan Sun authored and facebook-github-bot committed Dec 23, 2024
1 parent 464a0e9 commit c15d0bb
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 1,064 deletions.
18 changes: 0 additions & 18 deletions torchrec/distributed/embedding_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable


torch.fx.wrap("len")

CACHE_LOAD_FACTOR_STR: str = "cache_load_factor"
Expand All @@ -62,15 +61,6 @@ def _fx_wrap_tensor_to_device_dtype(
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


@torch.fx.wrap
def _fx_wrap_optional_tensor_to_device_dtype(
t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor
) -> Optional[torch.Tensor]:
if t is None:
return None
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)


@torch.fx.wrap
def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]:
return (
Expand Down Expand Up @@ -131,7 +121,6 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
block_sizes: torch.Tensor,
bucketize_pos: bool = False,
block_bucketize_pos: Optional[List[torch.Tensor]] = None,
total_num_blocks: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
Expand All @@ -153,7 +142,6 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
bucketize_pos=bucketize_pos,
sequence=True,
block_sizes=block_sizes,
total_num_blocks=total_num_blocks,
my_size=num_buckets,
weights=kjt.weights_or_none(),
max_B=_fx_wrap_max_B(kjt),
Expand Down Expand Up @@ -301,7 +289,6 @@ def bucketize_kjt_inference(
kjt: KeyedJaggedTensor,
num_buckets: int,
block_sizes: torch.Tensor,
total_num_buckets: Optional[torch.Tensor] = None,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
is_sequence: bool = False,
Expand All @@ -316,7 +303,6 @@ def bucketize_kjt_inference(
Args:
num_buckets (int): number of buckets to bucketize the values into.
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
bucketize_pos (bool): output the changed position of the bucketized values or
not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature.
Expand All @@ -332,9 +318,6 @@ def bucketize_kjt_inference(
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
)
block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype(
total_num_buckets, kjt.values()
)
unbucketize_permute = None
bucket_mapping = None
if is_sequence:
Expand All @@ -349,7 +332,6 @@ def bucketize_kjt_inference(
kjt,
num_buckets=num_buckets,
block_sizes=block_sizes_new_type,
total_num_blocks=total_num_buckets_new_type,
bucketize_pos=bucketize_pos,
block_bucketize_pos=block_bucketize_row_pos,
)
Expand Down
Loading

0 comments on commit c15d0bb

Please sign in to comment.