From 31aec12ab643665f03eba492dedd7f2ecf33497e Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Mon, 16 Dec 2024 10:39:33 -0800
Subject: [PATCH] Remove unused alltoallv function (#2630)

Summary: This code isn't referenced or used anywhere.

Reviewed By: PaulZhang12

Differential Revision: D67110533
---
 torchrec/distributed/comm_ops.py        | 109 ------------------------
 torchrec/distributed/tests/test_comm.py |  78 -----------------
 2 files changed, 187 deletions(-)

diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py
index cffcdecda..856f50d10 100644
--- a/torchrec/distributed/comm_ops.py
+++ b/torchrec/distributed/comm_ops.py
@@ -739,115 +739,6 @@ def all2all_sequence_sync(
     return sharded_output_embeddings.view(-1, D)
 
 
-def alltoallv(
-    inputs: List[Tensor],
-    out_split: Optional[List[int]] = None,
-    per_rank_split_lengths: Optional[List[int]] = None,
-    group: Optional[dist.ProcessGroup] = None,
-    codecs: Optional[QuantizedCommCodecs] = None,
-) -> Awaitable[List[Tensor]]:
-    """
-    Performs `alltoallv` operation for a list of input embeddings. Each process scatters
-    the list to all processes in the group.
-
-    Args:
-        inputs (List[Tensor]): list of tensors to scatter, one per rank. The tensors in
-            the list usually have different lengths.
-        out_split (Optional[List[int]]): output split sizes (or dim_sum_per_rank), if
-            not specified, we will use `per_rank_split_lengths` to construct a output
-            split with the assumption that all the embs have the same dimension.
-        per_rank_split_lengths (Optional[List[int]]): split lengths per rank. If not
-            specified, the `out_split` must be specified.
-        group (Optional[dist.ProcessGroup]): the process group to work on. If None, the
-            default process group will be used.
-        codecs (Optional[QuantizedCommCodecs]): quantized communication codecs.
-
-    Returns:
-        Awaitable[List[Tensor]]: async work handle (`Awaitable`), which can be `wait()` later to get the resulting list of tensors.
-
-    .. warning::
-        `alltoallv` is experimental and subject to change.
-    """
-
-    if group is None:
-        group = dist.distributed_c10d._get_default_group()
-
-    world_size: int = group.size()
-    my_rank: int = group.rank()
-
-    B_global = inputs[0].size(0)
-
-    D_local_list = []
-    for e in inputs:
-        D_local_list.append(e.size()[1])
-
-    B_local, B_local_list = _get_split_lengths_by_len(world_size, my_rank, B_global)
-
-    if out_split is not None:
-        dims_sum_per_rank = out_split
-    elif per_rank_split_lengths is not None:
-        # all the embs have the same dimension
-        dims_sum_per_rank = []
-        for s in per_rank_split_lengths:
-            dims_sum_per_rank.append(s * D_local_list[0])
-    else:
-        raise RuntimeError("Need to specify either out_split or per_rank_split_lengths")
-
-    a2ai = All2AllVInfo(
-        dims_sum_per_rank=dims_sum_per_rank,
-        B_local=B_local,
-        B_local_list=B_local_list,
-        D_local_list=D_local_list,
-        B_global=B_global,
-        codecs=codecs,
-    )
-
-    if get_use_sync_collectives():
-        return NoWait(all2allv_sync(group, a2ai, inputs))
-
-    myreq = Request(group, device=inputs[0].device)
-    All2Allv_Req.apply(group, myreq, a2ai, inputs)
-
-    return myreq
-
-
-def all2allv_sync(
-    pg: dist.ProcessGroup,
-    a2ai: All2AllVInfo,
-    inputs: List[Tensor],
-) -> List[Tensor]:
-    input_split_sizes = []
-    sum_D_local_list = sum(a2ai.D_local_list)
-    for m in a2ai.B_local_list:
-        input_split_sizes.append(m * sum_D_local_list)
-
-    output_split_sizes = []
-    for e in a2ai.dims_sum_per_rank:
-        output_split_sizes.append(a2ai.B_local * e)
-
-    input = torch.cat(inputs, dim=1).view([-1])
-    if a2ai.codecs is not None:
-        input = a2ai.codecs.forward.encode(input)
-
-    with record_function("## alltoallv_bwd_single ##"):
-        output = torch.ops.torchrec.all_to_all_single(
-            input,
-            output_split_sizes,
-            input_split_sizes,
-            pg_name(pg),
-            pg.size(),
-            get_gradient_division(),
-        )
-
-    if a2ai.codecs is not None:
-        output = a2ai.codecs.forward.decode(output)
-
-    outputs = []
-    for out in output.split(output_split_sizes):
-        outputs.append(out.view([a2ai.B_local, -1]))
-    return outputs
-
-
 def reduce_scatter_pooled(
     inputs: List[Tensor],
     group: Optional[dist.ProcessGroup] = None,
diff --git a/torchrec/distributed/tests/test_comm.py b/torchrec/distributed/tests/test_comm.py
index 02dbf02f3..d110e9740 100644
--- a/torchrec/distributed/tests/test_comm.py
+++ b/torchrec/distributed/tests/test_comm.py
@@ -204,84 +204,6 @@ def _run_multi_process_test(
             p.join()
             self.assertEqual(0, p.exitcode)
 
-    @classmethod
-    def _test_alltoallv(
-        cls,
-        rank: int,
-        world_size: int,
-        backend: str,
-        compile_config: _CompileConfig,
-        specify_pg: bool,
-    ) -> None:
-        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
-        pg = GroupMember.WORLD
-        assert pg is not None
-
-        device = torch.device(f"cuda:{rank}")
-
-        torch.cuda.set_device(device)
-
-        B_global = 10
-        D0 = 8
-        D1 = 9
-
-        input_embedding0 = torch.rand(
-            (B_global, D0),
-            device=device,
-            requires_grad=True,
-        )
-        input_embedding1 = torch.rand(
-            (B_global, D1),
-            device=device,
-            requires_grad=True,
-        )
-
-        input_embeddings = [input_embedding0, input_embedding1]
-        out_split = [17, 17]
-
-        # pyre-ignore
-        def fn(*args, **kwargs) -> List[torch.Tensor]:
-            return comm_ops.alltoallv(*args, **kwargs).wait()
-
-        fn_transform = compile_config_to_fn_transform(compile_config)
-
-        with unittest.mock.patch(
-            "torch._dynamo.config.skip_torchrec",
-            False,
-        ):
-            v_embs_out = fn_transform(fn)(
-                input_embeddings, out_split=out_split, group=pg if specify_pg else None
-            )
-
-        res = torch.cat(v_embs_out, dim=1).cpu()
-        assert tuple(res.size()) == (5, 34)
-        dist.destroy_process_group()
-
-    @unittest.skipIf(
-        torch.cuda.device_count() < 2, "Need at least two ranks to run this test"
-    )
-    # pyre-ignore
-    @given(
-        specify_pg=st.sampled_from([True]),
-        test_compiled_with_noncompiled_ranks=st.sampled_from([False, True]),
-    )
-    @settings(deadline=None)
-    def test_alltoallv(
-        self,
-        specify_pg: bool,
-        test_compiled_with_noncompiled_ranks: bool,
-    ) -> None:
-        self._run_multi_process_test(
-            world_size=self.WORLD_SIZE,
-            backend="nccl",
-            # pyre-ignore [6]
-            callable=self._test_alltoallv,
-            compile_config=_CompileConfig(
-                test_compiled_with_noncompiled_ranks=test_compiled_with_noncompiled_ranks
-            ),
-            specify_pg=specify_pg,
-        )
-
     @classmethod
     def _test_alltoall_sequence(
         cls,