Skip to content

Support torch.distributed.scatter collective #9365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,29 @@ def test_reduce_scatter(self, pin_layout):
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [-ordinal])

@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])

@staticmethod
def _all_to_all(pin_layout):
device = torch_xla.device()
Expand Down
1 change: 0 additions & 1 deletion test/test_torch_distributed_xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ def test_barrier(self):
'allreduce_coalesced',
'alltoall',
'gather',
'scatter',
'recv_anysource',
'monitored_barrier',
)
Expand Down
20 changes: 17 additions & 3 deletions torch_xla/distributed/xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch_xla._internal import rendezvous
import logging
import os
from torch._C._distributed_c10d import ProcessGroup, AllgatherOptions
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions


def _create_xla_process_group(prefix_store, rank, size, timeout):
Expand Down Expand Up @@ -253,8 +253,22 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
def gather(self, *args):
raise NotImplementedError

def scatter(self, *args):
raise NotImplementedError
# Called by torch.distributed.scatter. Call site example:
# https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
# Input tensors are defined on the source device and scattered
# to the output tensors.
def scatter(self, output_tensor_list: list[torch.Tensor],
input_tensors_list: list[list[torch.Tensor]],
opts: ScatterOptions):
if xr.global_ordinal() == opts.rootRank:
inputs = input_tensors_list
else:
inputs = [[torch.zeros_like(output_tensor)] * xr.world_size()
for output_tensor in output_tensor_list]

rs_opts = ReduceScatterOptions()
rs_opts.reduceOp = dist.ReduceOp.SUM
return self.reduce_scatter(output_tensor_list, inputs, rs_opts)

# Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
# the maker with their specific one. See unit test in
Expand Down