
Commit

option to output DTensor in state dict through fused_params (pytorch#2277)

Summary:
Pull Request resolved: pytorch#2277

Users can optionally receive DTensor in the state dict, as opposed to ShardedTensor, through a fused_params setting; the default is False.
```
fused_params: {"output_dtensor": True}
```
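
For illustration, a minimal sketch of how this flag might be passed in practice, assuming the usual torchrec pattern of supplying fused_params at sharder construction (the sharder class and surrounding setup below are not part of this diff):
```
# Sketch only, assuming fused_params is supplied via the sharder constructor;
# the exact integration point may differ per setup.
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

fused_params = {
    "output_dtensor": True,  # request DTensor instead of ShardedTensor in the state dict
}
sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)
# The sharder is then handed to the planner / DistributedModelParallel as usual.
```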

This diff only enables it for RW; these paths will be added for subsequent sharding schemes, such as CW/TWCW, in their respective diffs: D57063512

Reviewed By: joshuadeng

Differential Revision: D60932501

fbshipit-source-id: 97daafa106cc833bda4c67b751f390cc4f337f66
iamzainhuda authored and facebook-github-bot committed Aug 8, 2024
1 parent 26d4244 commit f7c1ca1
Showing 2 changed files with 16 additions and 5 deletions.
4 changes: 2 additions & 2 deletions torchrec/distributed/embedding_sharding.py
@@ -901,9 +901,9 @@ class EmbeddingSharding(abc.ABC, Generic[C, F, T, W], FeatureShardingMixIn):
     """
 
     def __init__(
-        self, qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None
+        self,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
     ) -> None:
-
         self._qcomm_codecs_registry = qcomm_codecs_registry
 
     @property
17 changes: 14 additions & 3 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -118,9 +118,7 @@ def __init__(
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
     ) -> None:
-        super().__init__(
-            qcomm_codecs_registry=qcomm_codecs_registry,
-        )
+        super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
 
         self._env = env
         self._pg: Optional[dist.ProcessGroup] = self._env.process_group
@@ -166,6 +164,18 @@ def _shard(
                 ),
             )
 
+            dtensor_metadata = None
+            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+                dtensor_metadata = DTensorMetadata(
+                    mesh=self._env.device_mesh,
+                    placements=(Shard(0),),
+                    size=(
+                        info.embedding_config.num_embeddings,
+                        info.embedding_config.embedding_dim,
+                    ),
+                    stride=info.param.stride(),
+                )
+
             for rank in range(self._world_size):
                 tables_per_rank[rank].append(
                     ShardedEmbeddingTable(
@@ -185,6 +195,7 @@
                         ),
                         local_metadata=shards[rank],
                         global_metadata=global_metadata,
+                        dtensor_metadata=dtensor_metadata,
                         weight_init_max=info.embedding_config.weight_init_max,
                         weight_init_min=info.embedding_config.weight_init_min,
                         fused_params=info.fused_params,
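
To connect the new metadata to what the state dict exposes: with row-wise (RW) sharding, each rank holds a contiguous block of rows, which corresponds to a DTensor with a Shard(0) placement over a 1D device mesh, i.e. the mesh/placements/size/stride captured by dtensor_metadata above. A small standalone sketch of that placement, using only torch's DTensor APIs (not torchrec code; assumes torch.distributed is already initialized, e.g. via torchrun):
```
# Standalone sketch: a row-wise sharded table expressed as a DTensor with a
# Shard(0) placement. Sizes are illustrative; assumes an initialized process group.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor

num_embeddings, embedding_dim = 1024, 64  # illustrative table shape
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

table = torch.empty(num_embeddings, embedding_dim)
# Shard(0) splits rows across ranks, matching placements=(Shard(0),) in the diff.
dtensor = distribute_tensor(table, mesh, placements=[Shard(0)])
print(dtensor.to_local().shape)  # each rank holds roughly num_embeddings / world_size rows
```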
