From af92456c5d0f67e94d71c21c659671a4474b4078 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Mon, 16 Jun 2025 16:04:12 +0000
Subject: [PATCH 1/7] write first draft of scatter implementation and test

---
 test/pjrt/test_collective_ops_tpu.py | 19 +++++++++++++++++++
 torch_xla/distributed/xla_backend.py | 15 ++++++++++++---
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 614040a81dc7..d6f92fdbeac9 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -103,6 +103,25 @@ def test_reduce_scatter(self, pin_layout):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])

+  @staticmethod
+  def _scatter():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    if xr.global_ordinal() == 0:
+      tensors = [torch.tensor([i], device=device, dtype=torch.float) for i in range(world_size)]
+    else:
+      tensors = None
+    output_tensor = torch.tensor([-1], device=device)
+    dist.scatter(output_tensor, tensors, src=0)
+    torch_xla.sync()
+    return output_tensor.cpu()
+
+  def test_scatter(self):
+    results = pjrt.run_multiprocess(self._scatter)
+    for ordinal, value in results.items():
+      np.testing.assert_array_equal(value, [ordinal])
+
   @staticmethod
   def _all_to_all(pin_layout):
     device = torch_xla.device()
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 7222a7bf3dcd..92c31b482cc2 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions


 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -232,8 +232,17 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
   def gather(self, *args):
     raise NotImplementedError

-  def scatter(self, *args):
-    raise NotImplementedError
+  def scatter(self, output_tensor_list: list[torch.Tensor], input_tensors_list: list[list[torch.Tensor]], opts: ScatterOptions):
+    output_tensor = output_tensor_list[0]
+    if xr.global_ordinal() == opts.rootRank:
+      input_tensors = input_tensors_list[0]
+    else:
+      input_tensors = [torch.zeros_like(output_tensor)] * xr.world_size()
+
+    rs_opts = ReduceScatterOptions()
+    rs_opts.reduceOp = dist.ReduceOp.SUM
+    return self.reduce_scatter([output_tensor], [input_tensors], rs_opts)
+

 # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
 # the maker with their specific one. See unit test in

From 4940e19d6888a31ddac59338d672a51cb7cfeaba Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Mon, 16 Jun 2025 16:51:26 +0000
Subject: [PATCH 2/7] format, fix small issues

---
 test/pjrt/test_collective_ops_tpu.py | 8 +++++---
 torch_xla/distributed/xla_backend.py | 5 +++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index d6f92fdbeac9..a34e5169611a 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -109,12 +109,14 @@ def _scatter():
     device = torch_xla.device()
     world_size = xr.world_size()
     if xr.global_ordinal() == 0:
-      tensors = [torch.tensor([i], device=device, dtype=torch.float) for i in range(world_size)]
+      tensors = [
+          torch.tensor([i], device=device, dtype=torch.float)
+          for i in range(world_size)
+      ]
     else:
       tensors = None
-    output_tensor = torch.tensor([-1], device=device)
+    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
     dist.scatter(output_tensor, tensors, src=0)
-    torch_xla.sync()
     return output_tensor.cpu()

   def test_scatter(self):
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 92c31b482cc2..0253099d55d4 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -232,7 +232,9 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
   def gather(self, *args):
     raise NotImplementedError

-  def scatter(self, output_tensor_list: list[torch.Tensor], input_tensors_list: list[list[torch.Tensor]], opts: ScatterOptions):
+  def scatter(self, output_tensor_list: list[torch.Tensor],
+              input_tensors_list: list[list[torch.Tensor]],
+              opts: ScatterOptions):
     output_tensor = output_tensor_list[0]
     if xr.global_ordinal() == opts.rootRank:
       input_tensors = input_tensors_list[0]
@@ -243,7 +245,6 @@ def scatter(self, output_tensor_list: list[torch.Tensor], input_tensors_list: li
     rs_opts = ReduceScatterOptions()
     rs_opts.reduceOp = dist.ReduceOp.SUM
     return self.reduce_scatter([output_tensor], [input_tensors], rs_opts)
-
 # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
 # the maker with their specific one. See unit test in
 # test/test_torch_distributed_xla_backend.py for an example.

From fcf2803fc9f3cdf1459e3571d8f3165c1797d981 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Mon, 16 Jun 2025 17:10:35 +0000
Subject: [PATCH 3/7] generalize implementation to work on longer lists

---
 torch_xla/distributed/xla_backend.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 0253099d55d4..1943c06e5962 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -235,15 +235,15 @@ def gather(self, *args):
   def scatter(self, output_tensor_list: list[torch.Tensor],
               input_tensors_list: list[list[torch.Tensor]],
               opts: ScatterOptions):
-    output_tensor = output_tensor_list[0]
     if xr.global_ordinal() == opts.rootRank:
-      input_tensors = input_tensors_list[0]
+      inputs = input_tensors_list
     else:
-      input_tensors = [torch.zeros_like(output_tensor)] * xr.world_size()
+      inputs = [[torch.zeros_like(output_tensor)] * xr.world_size()
+                for output_tensor in output_tensor_list]

     rs_opts = ReduceScatterOptions()
     rs_opts.reduceOp = dist.ReduceOp.SUM
-    return self.reduce_scatter([output_tensor], [input_tensors], rs_opts)
+    return self.reduce_scatter(output_tensor_list, inputs, rs_opts)

 # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
 # the maker with their specific one. See unit test in

From 04032e83a086c2e7e5e94b41a1a9f38080922127 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Tue, 17 Jun 2025 05:11:48 +0000
Subject: [PATCH 4/7] remove scatter from unimplemented list

---
 test/test_torch_distributed_xla_backend.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index a3069a6637ec..99b721a4fa16 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -360,7 +360,6 @@ def test_barrier(self):
       'allreduce_coalesced',
       'alltoall',
       'gather',
-      'scatter',
       'recv_anysource',
       'monitored_barrier',
   )

From 25ffd2ff7301e5a6022c00fc24b52855992e9dc8 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Tue, 17 Jun 2025 05:34:27 +0000
Subject: [PATCH 5/7] remove extra blank line

---
 test/pjrt/test_collective_ops_tpu.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 53b31ed1a158..3cc2f8e7d2d1 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -357,6 +357,5 @@ def test_all_to_all_single(self, use_dynamo):
                             expected.sort().values),
         f"Got {val}, expected {expected}")

-
 if __name__ == '__main__':
   absltest.main()

From 448bd65e337383eed50b0598bf74c1cac6c8fc8e Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Wed, 18 Jun 2025 17:32:11 +0000
Subject: [PATCH 6/7] improved documentation

---
 test/pjrt/test_collective_ops_tpu.py | 4 ++++
 torch_xla/distributed/xla_backend.py | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 3cc2f8e7d2d1..0d4cd3930a51 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -115,11 +115,14 @@ def _scatter():
       ]
     else:
       tensors = None
+
     output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
     dist.scatter(output_tensor, tensors, src=0)
     return output_tensor.cpu()

   def test_scatter(self):
+    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
+    on device 0, then scatters it. Device i should therefore receive [i]."""
     results = pjrt.run_multiprocess(self._scatter)
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [ordinal])
@@ -357,5 +360,6 @@ def test_all_to_all_single(self, use_dynamo):
                             expected.sort().values),
         f"Got {val}, expected {expected}")

+
 if __name__ == '__main__':
   absltest.main()
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index b110f58e66bc..948e31a08b1b 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -253,6 +253,10 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
   def gather(self, *args):
     raise NotImplementedError

+  # Called by torch.distributed.scatter. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
+  # Input tensors are defined on the source device and scattered
+  # to the output tensors.
   def scatter(self, output_tensor_list: list[torch.Tensor],
               input_tensors_list: list[list[torch.Tensor]],
               opts: ScatterOptions):

From e8db11635768cf4f6e97e50dac7b5291b4041d60 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Tue, 24 Jun 2025 07:26:07 -0700
Subject: [PATCH 7/7] Consolidate if-else

---
 test/pjrt/test_collective_ops_tpu.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 0d4cd3930a51..7ee9e7d8a66f 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -108,13 +108,12 @@ def _scatter():
     dist.init_process_group("xla", init_method='xla://')
     device = torch_xla.device()
     world_size = xr.world_size()
+    tensors = None
     if xr.global_ordinal() == 0:
       tensors = [
           torch.tensor([i], device=device, dtype=torch.float)
           for i in range(world_size)
       ]
-    else:
-      tensors = None

     output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
     dist.scatter(output_tensor, tensors, src=0)
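
The core idea in this series is that scatter can be expressed as a reduce_scatter in which only the root rank contributes real data and every other rank contributes zeros, so a SUM reduction leaves each rank holding exactly the root's slice. The snippet below is a minimal sketch of that arithmetic using plain CPU tensors, outside the XLA backend; the world size, root rank, and tensor shapes are made up for illustration and the per-rank loop stands in for what the collective does across real processes.

import torch

# Hypothetical setup: 4 ranks, rank 0 scatters the tensors [0.], [1.], [2.], [3.].
world_size = 4
root = 0
root_inputs = [torch.tensor([float(i)]) for i in range(world_size)]


def contribution(rank):
  # What each rank feeds into the reduce_scatter: the root supplies the real
  # inputs, every other rank supplies zeros of the same shape.
  if rank == root:
    return root_inputs
  return [torch.zeros(1) for _ in range(world_size)]


# A SUM reduce_scatter delivers to rank r the sum over all ranks of their
# r-th contribution, which here collapses to the root's r-th tensor.
for r in range(world_size):
  received = sum(contribution(rank)[r] for rank in range(world_size))
  assert torch.equal(received, root_inputs[r])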