Add Support for FakeProcessGroup (pytorch#1877)
Summary:
X-link: pytorch/pytorch#133039

Pull Request resolved: pytorch#1877

# context
* use `FakeProcessGroup` to mimic the multi-process tests in a single process (see the setup sketch after the snippet below)
* can use `_test_compile_fake_pg_fn` as the single-process variable-batch (VB) compile test:
```
from torchrec.distributed.tests.test_pt2_multiprocess import _test_compile_fake_pg_fn
_test_compile_fake_pg_fn(
    rank=0,
    world_size=2,
)
```
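
A minimal sketch of how such a fake process group can be initialized directly, assuming PyTorch's `FakeStore` and the registered `fake` backend from `torch.testing._internal.distributed.fake_pg`; this snippet is illustrative and not part of the commit:

```
# Sketch only: assumes torch.testing._internal.distributed.fake_pg.FakeStore
# and the "fake" c10d backend; none of this is added by the commit itself.
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# Single process pretending to be rank 0 of a 2-rank job; collectives are
# bookkeeping-only, no real communication happens.
dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())
pg = dist.distributed_c10d._get_default_group()

assert pg.rank() == 0 and pg.size() == 2
# The diffs below branch on this backend name:
assert pg._get_backend_name() == "fake"

dist.destroy_process_group()
```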

reference: D59637444

Differential Revision: D56124045
TroyGarden authored and facebook-github-bot committed Aug 8, 2024
1 parent 5f8a495 commit f04d9a6
Showing 6 changed files with 285 additions and 37 deletions.
32 changes: 20 additions & 12 deletions torchrec/distributed/batched_embedding_kernel.py
```
@@ -736,13 +736,17 @@ def __init__(
             self._emb_module,
             pg,
         )
-        self._param_per_table: Dict[str, nn.Parameter] = dict(
-            _gen_named_parameters_by_table_ssd(
-                emb_module=self._emb_module,
-                table_name_to_count=self.table_name_to_count.copy(),
-                config=self._config,
-                pg=pg,
+        self._param_per_table: Dict[str, nn.Parameter] = (
+            dict(
+                _gen_named_parameters_by_table_ssd(
+                    emb_module=self._emb_module,
+                    table_name_to_count=self.table_name_to_count.copy(),
+                    config=self._config,
+                    pg=pg,
+                )
             )
+            if pg._get_backend_name() != "fake"
+            else {}
         )
         self.init_parameters()

@@ -1308,13 +1312,17 @@ def __init__(
             self._emb_module,
             pg,
         )
-        self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
-            _gen_named_parameters_by_table_fused(
-                emb_module=self._emb_module,
-                table_name_to_count=self.table_name_to_count.copy(),
-                config=self._config,
-                pg=pg,
+        self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = (
+            dict(
+                _gen_named_parameters_by_table_fused(
+                    emb_module=self._emb_module,
+                    table_name_to_count=self.table_name_to_count.copy(),
+                    config=self._config,
+                    pg=pg,
+                )
             )
+            if pg._get_backend_name() != "fake"
+            else {}
         )
         self.init_parameters()
```
31 changes: 21 additions & 10 deletions torchrec/distributed/comm_ops.py
```
@@ -481,7 +481,7 @@ def variable_batch_alltoall_pooled(
     if group is None:
         group = dist.distributed_c10d._get_default_group()

-    if dist.get_world_size(group) <= 1:
+    if group.size() <= 1:
         return NoWait(a2a_pooled_embs_tensor)

     a2ai = VariableBatchAll2AllPooledInfo(
@@ -509,7 +509,7 @@ def variable_batch_all2all_pooled_sync(
     my_rank = pg.rank()

     # get input splits
-    world_size = dist.get_world_size(pg)
+    world_size = pg.size()
     input_split_sizes = [0 for _ in range(world_size)]
     if a2ai.batch_size_per_rank_per_feature:
         for i in range(world_size):
@@ -559,14 +559,25 @@ def variable_batch_all2all_pooled_sync(
     ]

     with record_function("## alltoall_fwd_single ##"):
-        sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
-            sharded_input_embeddings,
-            output_split_sizes,
-            input_split_sizes,
-            pg_name(pg),
-            pg.size(),
-            get_gradient_division(),
-        )
+        if pg._get_backend_name() == "fake":
+            sharded_output_embeddings = torch.empty(
+                sum(output_split_sizes),
+                device=sharded_input_embeddings.device,
+                dtype=sharded_input_embeddings.dtype,
+            )
+            s0 = sharded_output_embeddings.size(0)
+            # Bad assumption that our rank GE than other
+            torch._check(s0 <= sharded_input_embeddings.size(0))
+            sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
+        else:
+            sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
+                sharded_input_embeddings,
+                output_split_sizes,
+                input_split_sizes,
+                pg_name(pg),
+                pg.size(),
+                get_gradient_division(),
+            )

     if a2ai.codecs is not None:
         codecs = none_throws(a2ai.codecs)
```
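
As an illustration of the shape contract the fake branch above relies on (made-up split sizes, not from the commit): `all_to_all_single` returns a flat tensor of `sum(output_split_sizes)` elements, so under a fake process group it is enough to allocate a tensor of that length and fill it from the local input.

```
import torch

# Hypothetical splits for a 2-rank job: this rank sends 3 + 5 pooled values
# and expects 3 + 4 values back from the all-to-all.
input_split_sizes = [3, 5]
output_split_sizes = [3, 4]
sharded_input_embeddings = torch.randn(sum(input_split_sizes))

# What the fake branch does: allocate the correctly sized output and copy in
# the first sum(output_split_sizes) values of the local input.
sharded_output_embeddings = torch.empty(
    sum(output_split_sizes), dtype=sharded_input_embeddings.dtype
)
s0 = sharded_output_embeddings.size(0)
assert s0 <= sharded_input_embeddings.size(0)  # mirrors the torch._check above
sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
assert sharded_output_embeddings.numel() == sum(output_split_sizes)
```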
56 changes: 43 additions & 13 deletions torchrec/distributed/dist_data.py
```
@@ -239,12 +239,25 @@ def __init__(
             # https://github.com/pytorch/pytorch/issues/122788
             with record_function("## all2all_data:kjt splits ##"):
                 input_tensor = torch.stack(input_tensors, dim=1).flatten()
-                self._output_tensor = dist._functional_collectives.all_to_all_single(
-                    input_tensor,
-                    output_split_sizes=None,
-                    input_split_sizes=None,
-                    group=pg,
-                )
+                if pg._get_backend_name() == "fake":
+                    self._output_tensor = torch.empty(
+                        [self.num_workers * len(input_tensors)],
+                        device=input_tensors[0].device,
+                        dtype=input_tensors[0].dtype,
+                    )
+
+                    self._output_tensor = input_tensor[
+                        : input_tensor.size(0) // 2
+                    ].repeat(2)
+                else:
+                    self._output_tensor = (
+                        dist._functional_collectives.all_to_all_single(
+                            input_tensor,
+                            output_split_sizes=None,
+                            input_split_sizes=None,
+                            group=pg,
+                        )
+                    )
             # To avoid hasattr in _wait_impl to check self._splits_awaitable
             # pyre-ignore
             self._splits_awaitable = None
@@ -342,6 +355,7 @@ def __init__(
         self._output_tensors: List[torch.Tensor] = []
         self._awaitables: List[dist.Work] = []
         self._world_size: int = self._pg.size()
+        rank = self._pg.rank()

         for input_split, output_split, input_tensor, label in zip(
             input_splits,
@@ -353,12 +367,28 @@ def __init__(
                 # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
                 # https://github.com/pytorch/pytorch/issues/122788
                 with record_function(f"## all2all_data:kjt {label} ##"):
-                    output_tensor = dist._functional_collectives.all_to_all_single(
-                        input_tensor,
-                        output_split,
-                        input_split,
-                        pg,
-                    )
+                    if self._pg._get_backend_name() == "fake":
+                        output_tensor = torch.empty(
+                            sum(output_split),
+                            device=self._device,
+                            dtype=input_tensor.dtype,
+                        )
+                        _l = sum(output_split[:rank])
+                        _r = _l + output_split[rank]
+                        torch._check(_r < input_tensor.size(0))
+                        torch._check(_l < input_tensor.size(0))
+                        torch._check(_l <= _r)
+                        torch._check(2 * (_r - _l) == output_tensor.size(0))
+                        output_tensor.copy_(
+                            input_tensor[_l:_r].repeat(self._world_size)
+                        )
+                    else:
+                        output_tensor = dist._functional_collectives.all_to_all_single(
+                            input_tensor,
+                            output_split,
+                            input_split,
+                            pg,
+                        )
                     self._output_tensors.append(output_tensor)
             else:
                 output_tensor = torch.empty(
@@ -599,7 +629,7 @@ def forward(

         with torch.no_grad():
             assert len(input.keys()) == sum(self._splits)
-            rank = dist.get_rank(self._pg)
+            rank = self._pg.rank()
             local_keys = input.keys()[
                 self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
             ]
```
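
The KJT all-to-all fake branch above works the same way: it fabricates an output of length `sum(output_split)` by taking this rank's own slice of the input and repeating it `world_size` times. A made-up illustration of that slicing arithmetic (values chosen so the `torch._check` constraints hold for a 2-rank job):

```
import torch

# Hypothetical lengths for world_size = 2, rank = 0: this rank expects
# 4 values back from rank 0 and 4 from rank 1.
world_size, rank = 2, 0
output_split = [4, 4]
input_tensor = torch.arange(10)

# Slice out this rank's own contribution ...
_l = sum(output_split[:rank])
_r = _l + output_split[rank]
own_slice = input_tensor[_l:_r]

# ... and repeat it world_size times to fill a sum(output_split)-sized output.
output_tensor = own_slice.repeat(world_size)
assert output_tensor.numel() == sum(output_split)
```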
4 changes: 3 additions & 1 deletion torchrec/distributed/embeddingbag.py
```
@@ -657,7 +657,9 @@ def __init__(
                 broadcast_buffers=True,
                 static_graph=True,
             )
-        self._initialize_torch_state()
+
+        if env.process_group._get_backend_name() != "fake":
+            self._initialize_torch_state()

         # TODO[zainhuda]: support module device coming from CPU
         if module.device not in ["meta", "cpu"] and module.device.type not in [
```