MCH module + Q/SQ EC Test / Bug Fixes (pytorch#2331)
Summary:
Pull Request resolved: pytorch#2331

Initial to-do: the prod strategy is to use an unsharded MCH module in front of Q / SQ EC. In general this seems OK, but the biggest issues are:

[2] The uneven-sharding flag was not respected in the RW sequence GPU case; easy fix. cc: gnahzg
[3] get_propogate_device is a bit unfortunate/messy and will be cleaned up in a follow-up task, but found an edge case w.r.t. the CPU path. cc: gnahzg

Reviewed By: gnahzg

Differential Revision: D61572421

fbshipit-source-id: 412ed7947a7cde1991518d6a979db3bc7832fc68
dstaay-fb authored and facebook-github-bot committed Aug 22, 2024
1 parent f7e444d commit 6f0ea08
Showing 4 changed files with 180 additions and 34 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/dist_data.py
@@ -676,7 +676,7 @@ def __init__(
"cpu" if device is None else device.type
) # TODO: replace hardcoded cpu with DEFAULT_DEVICE_TYPE in torchrec.distributed.types when torch package issue resolved
else:
# If no device is provided, use "cuda".
# BUG: device will default to cuda if cpu specified
self._device_type: str = (
device.type
if device is not None and device.type in {"meta", "cuda", "mtia"}
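To make the flagged edge case concrete, here is a minimal standalone sketch of the fallback pattern this hunk annotates. `_resolve_device_type` is a hypothetical helper, not code from this diff; the allow-list is taken from the context lines above.

```python
from typing import Optional

import torch


def _resolve_device_type(device: Optional[torch.device]) -> str:
    # Mirrors the fallback shown in the hunk above (names here are illustrative):
    # only "meta", "cuda", and "mtia" are passed through; anything else,
    # including an explicit torch.device("cpu"), silently becomes "cuda".
    return (
        device.type
        if device is not None and device.type in {"meta", "cuda", "mtia"}
        else "cuda"
    )


# The edge case called out by the new BUG comment:
assert _resolve_device_type(torch.device("cpu")) == "cuda"
assert _resolve_device_type(None) == "cuda"
assert _resolve_device_type(torch.device("meta")) == "meta"
```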
2 changes: 1 addition & 1 deletion torchrec/distributed/embedding_sharding.py
@@ -319,7 +319,7 @@ def bucketize_kjt_inference(
num_buckets=num_buckets,
block_sizes=block_sizes_new_type,
bucketize_pos=bucketize_pos,
block_bucketize_pos=block_bucketize_row_pos, # each tensor should have the same dtype as kjt.lengths()
block_bucketize_pos=block_bucketize_row_pos,
)
else:
(
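The comment touched by this hunk warns that every tensor passed as `block_bucketize_pos` should share the dtype of `kjt.lengths()`. A minimal sketch of what enforcing that could look like, assuming a hypothetical `_match_lengths_dtype` helper (not part of this diff or of torchrec):

```python
from typing import List, Optional

import torch


def _match_lengths_dtype(
    block_bucketize_row_pos: Optional[List[torch.Tensor]],
    lengths: torch.Tensor,
) -> Optional[List[torch.Tensor]]:
    # Hypothetical helper: cast each per-feature row-position tensor to the
    # dtype of kjt.lengths(), which is what the comment in the hunk above
    # is warning about (a dtype mismatch can trip up the bucketization op).
    if block_bucketize_row_pos is None:
        return None
    return [t.to(dtype=lengths.dtype) for t in block_bucketize_row_pos]


lengths = torch.tensor([2, 0, 1], dtype=torch.int32)
row_pos = [torch.tensor([0, 10, 20], dtype=torch.int64)]
matched = _match_lengths_dtype(row_pos, lengths)
assert matched is not None and matched[0].dtype == lengths.dtype
```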
17 changes: 5 additions & 12 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -31,6 +31,7 @@
)
from torchrec.distributed.sharding.rw_sharding import (
BaseRwEmbeddingSharding,
get_embedding_shard_metadata,
InferRwSparseFeaturesDist,
RwSparseFeaturesDist,
)
@@ -39,7 +40,6 @@
SequenceShardingContext,
)
from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
from torchrec.distributed.utils import none_throws
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@@ -199,16 +199,9 @@ def create_input_dist(
num_features = self._get_num_features()
feature_hash_sizes = self._get_feature_hash_sizes()

emb_sharding = []
for embedding_table_group in self._grouped_embedding_configs_per_rank[0]:
for table in embedding_table_group.embedding_tables:
shard_split_offsets = [
shard.shard_offsets[0]
for shard in none_throws(table.global_metadata).shards_metadata
]
shard_split_offsets.append(none_throws(table.global_metadata).size[0])
emb_sharding.extend([shard_split_offsets] * len(table.embedding_names))

(emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
self._grouped_embedding_configs_per_rank
)
return InferRwSparseFeaturesDist(
world_size=self._world_size,
num_features=num_features,
@@ -217,7 +210,7 @@ def create_input_dist(
is_sequence=True,
has_feature_processor=self._has_feature_processor,
need_pos=False,
embedding_shard_metadata=emb_sharding,
embedding_shard_metadata=emb_sharding if not is_even_sharding else None,
)

def create_lookup(
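For context on what the new helper replaces, here is an illustrative reconstruction based on the inline loop removed above. The real `get_embedding_shard_metadata` lives in `torchrec.distributed.sharding.rw_sharding` and may compute evenness differently, so treat this only as a sketch of the shape of its output: per-feature shard row offsets plus an even-sharding flag, where an even split lets the call site pass `embedding_shard_metadata=None`.

```python
from typing import List, Tuple


def _shard_offsets_and_evenness(
    grouped_embedding_configs_per_rank,  # same structure as in the removed loop above
) -> Tuple[List[List[int]], bool]:
    # Illustrative only; assumes table.global_metadata is populated.
    emb_sharding: List[List[int]] = []
    for embedding_table_group in grouped_embedding_configs_per_rank[0]:
        for table in embedding_table_group.embedding_tables:
            # Row offsets of each shard, plus the table's total row count.
            offsets = [
                shard.shard_offsets[0]
                for shard in table.global_metadata.shards_metadata
            ]
            offsets.append(table.global_metadata.size[0])
            emb_sharding.extend([offsets] * len(table.embedding_names))

    # "Even" row-wise sharding means every shard covers the same number of rows,
    # in which case the caller can skip the uneven-sharding path entirely.
    is_even = all(
        len(set(b - a for a, b in zip(offsets, offsets[1:]))) <= 1
        for offsets in emb_sharding
    )
    return emb_sharding, is_even
```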

