diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 28b0b770906..7ee9e7d8a66 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -103,6 +103,29 @@ def test_reduce_scatter(self, pin_layout):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])
 
+  @staticmethod
+  def _scatter():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    tensors = None
+    if xr.global_ordinal() == 0:
+      tensors = [
+          torch.tensor([i], device=device, dtype=torch.float)
+          for i in range(world_size)
+      ]
+
+    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
+    dist.scatter(output_tensor, tensors, src=0)
+    return output_tensor.cpu()
+
+  def test_scatter(self):
+    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
+    on device 0, then scatters it. Device i should therefore receive [i]."""
+    results = pjrt.run_multiprocess(self._scatter)
+    for ordinal, value in results.items():
+      np.testing.assert_array_equal(value, [ordinal])
+
   @staticmethod
   def _all_to_all(pin_layout):
     device = torch_xla.device()
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index a3069a6637e..99b721a4fa1 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -360,7 +360,6 @@ def test_barrier(self):
         'allreduce_coalesced',
         'alltoall',
         'gather',
-        'scatter',
         'recv_anysource',
         'monitored_barrier',
     )
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 8905ff81e9f..948e31a08b1 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, AllgatherOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -253,8 +253,22 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
   def gather(self, *args):
     raise NotImplementedError
 
-  def scatter(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.scatter. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
+  # Input tensors are defined on the source device and scattered
+  # to the output tensors.
+  def scatter(self, output_tensor_list: list[torch.Tensor],
+              input_tensors_list: list[list[torch.Tensor]],
+              opts: ScatterOptions):
+    if xr.global_ordinal() == opts.rootRank:
+      inputs = input_tensors_list
+    else:
+      inputs = [[torch.zeros_like(output_tensor)] * xr.world_size()
+                for output_tensor in output_tensor_list]
+
+    rs_opts = ReduceScatterOptions()
+    rs_opts.reduceOp = dist.ReduceOp.SUM
+    return self.reduce_scatter(output_tensor_list, inputs, rs_opts)
 
   # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
   # the maker with their specific one. See unit test in
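
Note on the approach: the new ProcessGroupXla.scatter lowers scatter onto the existing reduce_scatter by having every non-root rank contribute all-zero inputs, so a SUM reduce-scatter leaves exactly the root's shard on each rank. Below is a minimal single-process sketch of that equivalence; it runs no collectives, and the names world_size, per_rank_inputs, and received are illustrative assumptions, not part of the patch.

import torch

world_size = 4  # illustrative; the real code uses xr.world_size()

# Per-rank input lists, built the way the new scatter() builds them:
# rank 0 (the root) holds the real values, every other rank substitutes zeros.
per_rank_inputs = [[torch.tensor([float(i)]) for i in range(world_size)]]
per_rank_inputs += [[torch.zeros(1) for _ in range(world_size)]
                    for _ in range(world_size - 1)]

# Emulate reduce-scatter with ReduceOp.SUM: sum shard i across all ranks,
# then deliver the result to rank i.
received = [
    sum(inputs[i] for inputs in per_rank_inputs) for i in range(world_size)
]

# Each rank i ends up with [i], matching what test_scatter asserts.
for rank in range(world_size):
  assert torch.equal(received[rank], torch.tensor([float(rank)]))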