diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 614040a81dc..a95cf0078af 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -326,6 +326,28 @@ def test_all_to_all_single(self, use_dynamo):
           expected.sort().values), f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _send_recv():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    cutoff = world_size // 2
+    index = xr.global_ordinal()
+    tensor = torch.tensor([index + 1], dtype=torch.float, device=device)
+    if index < cutoff:
+      dist.send(tensor, index + cutoff)
+    else:
+      dist.recv(tensor, index - cutoff)
+    return tensor.cpu()
+
+  def test_send_recv(self):
+    """Send tensors on first N/2 devices to second N/2 devices."""
+    results = pjrt.run_multiprocess(self._send_recv)
+    world_size = tpu.num_expected_global_devices()
+    for ordinal, value in results.items():
+      expected = ordinal + 1 if ordinal < world_size // 2 else ordinal + 1 - world_size // 2
+      np.testing.assert_array_equal(value, [expected])
+
 
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index a3069a6637e..f064ba88c2c 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -166,47 +166,6 @@ def test_reduce_scatter_coalesced(self):
     # purge all computations attached the device.
     torch_xla.sync()
 
-  @patch_world(0, 6)
-  def test_send(self):
-    device = torch_xla.device()
-    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
-    input_list = [tensor]
-
-    with mock.patch.object(
-        torch_xla.distributed.xla_backend.ProcessGroupXla,
-        'make_send_channel_id',
-        new=lambda self, dst_rank, tag: dst_rank * 2):
-      dist.send(tensor, 1)
-
-    send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
-    senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
-    hlo_matches(hlo, send_pattern)
-    hlo_matches(hlo, senddone_pattern)
-
-    # Don't try to run Send on CPU because it's not implemented
-    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))
-
-  @patch_world(0, 6)
-  def test_recv(self):
-    device = torch_xla.device()
-    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
-
-    with mock.patch.object(
-        torch_xla.distributed.xla_backend.ProcessGroupXla,
-        'make_recv_channel_id',
-        new=lambda self, src_rank, tag: src_rank * 3):
-      dist.recv(tensor, 1)
-
-    recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
-    recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
-    hlo_matches(hlo, recv_pattern)
-    hlo_matches(hlo, recvdone_pattern)
-
-    # Don't try to run Recv on CPU because it's not implemented
-    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))
-
   @patch_world(rank=0, size=12)
   def test_new_group_no_ranks(self):
     with new_group_barrier_disabled():
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6b68e656d33..3dbad1a963e 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor,
                        pairs: List[List[int]]) -> torch.Tensor:
   """Performs a XLA `CollectivePermute()` operation on the input tensor.
 
-  WARNING: This function is not very reliable, may produce wrong results under
-  certain inputs. Use it at your own risk.
-
   See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute
 
   Args:
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 7222a7bf3dc..99992ca4b22 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -235,41 +235,36 @@ def gather(self, *args):
   def scatter(self, *args):
     raise NotImplementedError
 
-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_send_channel_id(self, dst_rank, tag):
-    raise NotImplementedError
-
   # Call site e.g.
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877
   def send(self, tensors, dst_rank, tag=0):
+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute()."
+    )
     results = []
     for t in tensors:
-      channel_id = self.make_send_channel_id(dst_rank, tag)
-      # The input will be returned as result.
-      input_as_result = xm.send(t, channel_id)
-      # Make the sent tensor depend on the token, such that the `send`
-      # op can actually be built into the computation graph.
+      result_t = xm.collective_permute(
+          t, pairs=[[xr.global_ordinal(), dst_rank]])
+      # Every process must have the same IR, otherwise they deadlock. But in
+      # the receiving process the provided tensor receives the result, while
+      # in the sending process it is unchanged. The solution used here is to
+      # have every process copy a linear combination of the two tensors, but
+      # send/recv use different coefficients to achieve different outcomes.
       with torch.no_grad():
-        t.copy_(input_as_result)
-      results.append(input_as_result)
+        t.copy_(result_t * 0.0 + t * 1.0)
+      results.append(result_t)
     return _ret_work(results)
 
-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_recv_channel_id(self, src_rank, tag):
-    raise NotImplementedError
-
   # Call site e.g.
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913
   def recv(self, out_tensors, src_rank, tag=0):
     results = []
     for ot in out_tensors:
-      channel_id = self.make_recv_channel_id(src_rank, tag)
-      result = xm.recv(ot, channel_id)
-      results.append(result)
+      result_t = xm.collective_permute(
+          ot, pairs=[[src_rank, xr.global_ordinal()]])
+      with torch.no_grad():
+        ot.copy_(result_t * 1.0 + ot * 0.0)
+      results.append(result_t)
     return _ret_work(results)
 
   def recv_anysource(self, *args):
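Reviewer note, not part of the patch: a minimal sketch of the symmetric graph that the new `send`/`recv` paths trace, assuming a live multi-device XLA runtime and a process that participates as either the source or the destination rank. The helper name `_paired_permute` is illustrative only. Both ends issue a `collective_permute` with the same `pairs=[[src_rank, dst_rank]]`, so the collective op lines up across processes; only the coefficients of the follow-up copy decide whether the local value is kept (send) or replaced by the incoming one (recv), which is the linear-combination trick described in the comment above.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def _paired_permute(tensor, src_rank, dst_rank):
  # Both ranks build the same CollectivePermute moving data from src_rank
  # to dst_rank, avoiding the deadlock that mismatched graphs would cause.
  permuted = xm.collective_permute(tensor, pairs=[[src_rank, dst_rank]])
  with torch.no_grad():
    if xr.global_ordinal() == dst_rank:
      # recv path: take the incoming value, discard the local one.
      tensor.copy_(permuted * 1.0 + tensor * 0.0)
    else:
      # send path: keep the local value unchanged.
      tensor.copy_(permuted * 0.0 + tensor * 1.0)
  return tensor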