Commit 01db65d

Support torch.distributed.scatter collective (#9365)
1 parent 2d81349 commit 01db65d

File tree

3 files changed: +40 lines, -4 lines


test/pjrt/test_collective_ops_tpu.py

Lines changed: 23 additions & 0 deletions
@@ -103,6 +103,29 @@ def test_reduce_scatter(self, pin_layout):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])
 
+  @staticmethod
+  def _scatter():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    tensors = None
+    if xr.global_ordinal() == 0:
+      tensors = [
+          torch.tensor([i], device=device, dtype=torch.float)
+          for i in range(world_size)
+      ]
+
+    output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
+    dist.scatter(output_tensor, tensors, src=0)
+    return output_tensor.cpu()
+
+  def test_scatter(self):
+    """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
+    on device 0, then scatters it. Device i should therefore receive [i]."""
+    results = pjrt.run_multiprocess(self._scatter)
+    for ordinal, value in results.items():
+      np.testing.assert_array_equal(value, [ordinal])
+
   @staticmethod
   def _all_to_all(pin_layout):
     device = torch_xla.device()
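
For reference, the call pattern exercised by the new test can also be used directly. The sketch below assumes an XLA/PJRT runtime and reuses pjrt.run_multiprocess the same way the test does; the main() wrapper and the printed result are illustrative rather than part of the commit.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr
from torch_xla._internal import pjrt


def main():
  # Join the XLA process group; each participant gets its own XLA device.
  dist.init_process_group("xla", init_method="xla://")
  device = torch_xla.device()
  world_size = xr.world_size()

  # Only the source rank supplies the list of input tensors; every other
  # rank passes None and simply receives its slice.
  scatter_list = None
  if xr.global_ordinal() == 0:
    scatter_list = [
        torch.tensor([float(i)], device=device) for i in range(world_size)
    ]

  out = torch.zeros(1, device=device)
  dist.scatter(out, scatter_list, src=0)  # rank i ends up holding [i]
  return out.cpu()


if __name__ == "__main__":
  results = pjrt.run_multiprocess(main)
  print(results)  # e.g. {0: tensor([0.]), 1: tensor([1.]), ...}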

test/test_torch_distributed_xla_backend.py

Lines changed: 0 additions & 1 deletion
@@ -360,7 +360,6 @@ def test_barrier(self):
     'allreduce_coalesced',
     'alltoall',
     'gather',
-    'scatter',
     'recv_anysource',
     'monitored_barrier',
 )

torch_xla/distributed/xla_backend.py

Lines changed: 17 additions & 3 deletions
@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, AllgatherOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -253,8 +253,22 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
   def gather(self, *args):
     raise NotImplementedError
 
-  def scatter(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.scatter. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
+  # Input tensors are defined on the source device and scattered
+  # to the output tensors.
+  def scatter(self, output_tensor_list: list[torch.Tensor],
+              input_tensors_list: list[list[torch.Tensor]],
+              opts: ScatterOptions):
+    if xr.global_ordinal() == opts.rootRank:
+      inputs = input_tensors_list
+    else:
+      inputs = [[torch.zeros_like(output_tensor)] * xr.world_size()
+                for output_tensor in output_tensor_list]
+
+    rs_opts = ReduceScatterOptions()
+    rs_opts.reduceOp = dist.ReduceOp.SUM
+    return self.reduce_scatter(output_tensor_list, inputs, rs_opts)
 
   # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
   # the maker with their specific one. See unit test in
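
The new scatter does not introduce a dedicated collective; it is routed through reduce_scatter: non-root ranks contribute all-zero inputs, so a SUM reduction leaves exactly the root's slice on each rank. The sketch below illustrates that identity with plain CPU tensors and an arbitrary world size of 4; it mimics only the arithmetic, not the XLA execution path.

import torch

world_size = 4
root = 0

# Per-rank input lists: the root holds the real data, all other ranks
# contribute zeros of the same shape.
root_inputs = [torch.tensor([float(i)]) for i in range(world_size)]
per_rank_inputs = [
    root_inputs if rank == root else [torch.zeros(1) for _ in range(world_size)]
    for rank in range(world_size)
]

# reduce_scatter with SUM: output chunk i is the sum of every rank's chunk i.
outputs = [
    sum(per_rank_inputs[rank][i] for rank in range(world_size))
    for i in range(world_size)
]

# Only the root's contribution is non-zero, so rank i receives root_inputs[i].
assert all(torch.equal(outputs[i], root_inputs[i]) for i in range(world_size))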
