
add 2D parallel and DTensor support to TWRW (pytorch#2629)
Summary:
Pull Request resolved: pytorch#2629

Add 2D parallelism and DTensor support for table-row-wise (TWRW) sharding.

Differential Revision: D67145321

fbshipit-source-id: eaa264131628c89fa5900fd2df9dc10b1d3da893
iamzainhuda authored and facebook-github-bot committed Dec 13, 2024
1 parent 3928a1b commit 019a92d
Showing 2 changed files with 143 additions and 8 deletions.
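
For orientation before the diffs: in a 2D scheme the model is replicated across data-parallel replica groups while each replica group shards its tables internally. Below is a minimal sketch of one assumed rank layout (illustrative numbers only, not the library's API; the real grouping comes from ShardingEnv2D and its process groups):

# Illustrative only: 8 global ranks, 2 replica groups of 4 sharding ranks each.
# Note that a sharding group's members are global ranks and need not be
# 0..N-1, which is why the TWRW change below remaps them to a logical ordering.
world_size = 8
num_replica_groups = 2
sharding_group_size = world_size // num_replica_groups

# One assumed layout: replica group r owns every num_replica_groups-th rank.
sharding_groups = [
    list(range(r, world_size, num_replica_groups))
    for r in range(num_replica_groups)
]
print(sharding_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]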
torchrec/distributed/sharding/twrw_sharding.py: 56 additions & 8 deletions
@@ -13,7 +13,13 @@

 import torch
 import torch.distributed as dist
-from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg
+from torch.distributed._tensor import Shard
+from torch.distributed.distributed_c10d import get_process_group_ranks
+from torchrec.distributed.comm import (
+    get_local_size,
+    intra_and_cross_node_pg,
+    intra_and_cross_node_pg_2D,
+)
 from torchrec.distributed.dist_data import (
     KJTAllToAll,
     PooledEmbeddingsAllToAll,
@@ -34,6 +40,7 @@
 )
 from torchrec.distributed.embedding_types import (
     BaseGroupedFeatureProcessor,
+    DTensorMetadata,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
@@ -44,6 +51,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingEnv2D,
     ShardingType,
     ShardMetadata,
 )
@@ -71,14 +79,26 @@ def __init__(
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
-        self._pg: Optional[dist.ProcessGroup] = self._env.process_group
+        self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D)
+        self._pg: Optional[dist.ProcessGroup] = (
+            self._env.sharding_pg  # pyre-ignore[16]
+            if self._is_2D_parallel
+            else self._env.process_group
+        )
         self._world_size: int = self._env.world_size
         self._rank: int = self._env.rank
         self._device = device
         self._need_pos = need_pos
-        intra_pg, cross_pg = intra_and_cross_node_pg(
-            device, backend=dist.get_backend(self._pg)
-        )
+        if self._is_2D_parallel:
+            intra_pg, cross_pg = intra_and_cross_node_pg_2D(
+                # pyre-fixme[6]
+                self._env,
+                device=device,
+            )
+        else:
+            intra_pg, cross_pg = intra_and_cross_node_pg(
+                device, backend=dist.get_backend(self._pg)
+            )
         self._intra_pg: Optional[dist.ProcessGroup] = intra_pg
         self._cross_pg: Optional[dist.ProcessGroup] = cross_pg
         self._local_size: int = (
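
The constructor change above swaps in intra_and_cross_node_pg_2D when running 2D. For intuition, here is a sketch of what intra- and cross-node groupings look like in the plain (non-2D) path, with assumed host and GPU counts; the real process groups come from torchrec.distributed.comm:

# Assumed topology for illustration: 2 hosts x 4 GPUs = 8 ranks.
# TWRW shards each table across the ranks of one host (an intra-node group)
# and exchanges pooled embeddings across hosts (cross-node groups).
world_size, local_size = 8, 4

intra_node = [
    list(range(node * local_size, (node + 1) * local_size))
    for node in range(world_size // local_size)
]
cross_node = [list(range(r, world_size, local_size)) for r in range(local_size)]

print(intra_node)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(cross_node)  # [[0, 4], [1, 5], [2, 6], [3, 7]]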
@@ -112,11 +132,23 @@ def _shard(
         world_size = self._world_size
         local_size = self._local_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
-            [] for i in range(world_size)
+            [] for _ in range(world_size)
         ]
+        peer_group = (
+            # pyre-ignore [6]
+            get_process_group_ranks(self._pg)
+            if self._is_2D_parallel
+            else None
+        )
         for info in sharding_infos:
-            # pyre-ignore [16]
-            table_node = info.param_sharding.ranks[0] // local_size
+            # Under 2D parallelism, remap the global rank to its logical ordering within the sharding group
+            rank = (
+                # pyre-ignore [16]
+                peer_group.index(info.param_sharding.ranks[0])
+                if peer_group is not None
+                else info.param_sharding.ranks[0]
+            )
+            table_node = rank // local_size
             # pyre-fixme [16]
             shards = info.param_sharding.sharding_spec.shards

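The remapping above is the heart of the 2D change: under 2D parallelism, info.param_sharding.ranks holds global ranks, while the TWRW node math expects logical 0..N-1 ranks within the sharding group. A runnable sketch with made-up values:

# Made-up values: one sharding group owning global ranks [0, 2, 4, 6],
# with 2 ranks per node (local_size).
peer_group = [0, 2, 4, 6]   # what get_process_group_ranks(sharding_pg) might return
local_size = 2
owner_global_rank = 4       # stands in for info.param_sharding.ranks[0]

logical_rank = peer_group.index(owner_global_rank)  # 2
table_node = logical_rank // local_size             # node 1 hosts this table
print(logical_rank, table_node)  # 2 1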
@@ -131,6 +163,21 @@
                 ),
             )

+            dtensor_metadata = None
+            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+                placements = (Shard(0),)
+                dtensor_metadata = DTensorMetadata(
+                    mesh=self._env.device_mesh,
+                    placements=placements,
+                    size=(
+                        info.embedding_config.num_embeddings,
+                        info.embedding_config.embedding_dim,
+                    ),
+                    stride=info.param.stride(),
+                )
+                # pop the flag so it is not passed on to TBE
+                info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]
+
             for rank in range(
                 table_node * local_size,
                 (table_node + 1) * local_size,
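
Two details in the block above are worth calling out. First, (Shard(0),) declares that the output DTensor is sharded along dim 0, the embedding rows, which matches TWRW's row-wise split within a node. Second, the "output_dtensor" flag is consumed here and popped so the fused TBE kernel never receives an unknown fused param. A small sketch of that consume-and-pop pattern (illustrative dict, not the real call site):

# Illustrative fused_params; only "output_dtensor" is interpreted at this point.
fused_params = {"learning_rate": 0.01, "output_dtensor": True}

wants_dtensor = fused_params.get("output_dtensor", False)
fused_params.pop("output_dtensor", None)  # TBE must not see this key

print(wants_dtensor)  # True
print(fused_params)   # {'learning_rate': 0.01}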
@@ -154,6 +201,7 @@
                         ),
                         local_metadata=shards[rank_idx],
                         global_metadata=global_metadata,
+                        dtensor_metadata=dtensor_metadata,
                         weight_init_max=info.embedding_config.weight_init_max,
                         weight_init_min=info.embedding_config.weight_init_min,
                         fused_params=info.fused_params,
torchrec/distributed/tests/test_2d_sharding.py: 87 additions & 0 deletions
@@ -402,3 +402,90 @@ def test_sharding_rw_2D(
             variable_batch_size=variable_batch_size,
             pooling=pooling,
         )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 7,
+        "Not enough GPUs, this test requires at least eight GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                # None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16, backward_precision=CommType.BF16
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embedding_bags": (
+                        torch.optim.SGD,
+                        {
+                            "lr": 0.01,
+                        },
+                    ),
+                },
+            ]
+        ),
+        pooling=st.sampled_from([PoolingType.SUM]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    def test_sharding_twrw_2D(
+        self,
+        sharder_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        pooling: PoolingType,
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharding_type = ShardingType.TABLE_ROW_WISE.value
+        assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value)
+
+        self._test_sharding(
+            world_size=self.WORLD_SIZE,
+            local_size=self.WORLD_SIZE_2D // 2,
+            world_size_2D=self.WORLD_SIZE_2D,
+            sharders=[
+                cast(
+                    ModuleSharder[nn.Module],
+                    create_test_sharder(
+                        sharder_type,
+                        sharding_type,
+                        kernel_type,
+                        qcomms_config=qcomms_config,
+                        device=self.device,
+                    ),
+                ),
+            ],
+            qcomms_config=qcomms_config,
+            constraints={
+                table.name: ParameterConstraints(min_partition=2)
+                for table in self.tables
+            },
+            backend=self.backend,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            pooling=pooling,
+        )
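
On the test parameters: local_size=self.WORLD_SIZE_2D // 2 makes each TWRW "node" two ranks wide within a sharding group, and min_partition=2 bounds how finely each table may be split across those ranks. Assuming WORLD_SIZE = 8 and WORLD_SIZE_2D = 4 on the test base class (those attributes are not shown in this diff), the arithmetic works out as:

# Assumed class attributes (not visible in this diff).
WORLD_SIZE, WORLD_SIZE_2D = 8, 4

num_replica_groups = WORLD_SIZE // WORLD_SIZE_2D        # 2 data-parallel replicas
local_size = WORLD_SIZE_2D // 2                         # 2 ranks per TWRW "node"
nodes_per_sharding_group = WORLD_SIZE_2D // local_size  # 2 "nodes" per group

print(num_replica_groups, local_size, nodes_per_sharding_group)  # 2 2 2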
