Skip to content

implement send and recv using collective_permute #9373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,28 @@ def test_all_to_all_single(self, use_dynamo):
expected.sort().values),
f"Got {val}, expected {expected}")

@staticmethod
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last time we checked, we also noticed that https://github.com/pytorch/xla/blob/master/test/test_mp_collective_permute.py didn't work on the CPU, but send/recv did. We might want to double check it.

Is test/test_torch_distributed_xla_backend.py tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is test/test_torch_distributed_xla_backend.py tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?

It is, but it just checks that the expected IR is emitted. It doesn't run anything. And in this case it wasn't a reliable test because, at least for TPU, that IR does not actually run.

test_mp_collective_permute is run for both TPU and Neuron. I don't think it works for CPU but neither do send/recv. The success of test_mp_collective_permute indicates this change should work for Neuron, but to be more certain I could add a test that covers a pipeline-like transfer in addition to the existing test of a permutation-like transfer.

The most direct test would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion, for Neuron.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The most direct test would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion, for Neuron.

This would be great. Any chance we can move it outside of this file and make it general? I can help test it out if so. Otherwise, I'll need to follow up if we can port this entire file to Neuron. I see tpu.num_expected_global_devices, and pjrt.run_multiprocess, but haven't seen/used these before.

def _send_recv():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
cutoff = world_size // 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if the world size is not even, this test will hang. For example, if world size is 3, then index 0 will send to 1 and 1 will recv from 0, but index 2 will try to recv from 1 without an associated send.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll update the test so that it is more defensive

index = xr.global_ordinal()
tensor = torch.tensor([index + 1], dtype=torch.float, device=device)
if index < cutoff:
dist.send(tensor, index + cutoff)
else:
dist.recv(tensor, index - cutoff)
return tensor.cpu()

def test_send_recv(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original test separated both send and receive. While this is more code efficient, it might be harder to debug as it will not be obvious what the issue is.

I think keeping a test for the total interaction is valid, but is there a way to replicate the other two tests that existed previously?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

send and recv don't work independently. The original test was a "dry run" -- it checked the IR but didn't execute. If it did execute it would fail.

"""Send tensors on first N/2 devices to second N/2 devices."""
results = pjrt.run_multiprocess(self._send_recv)
world_size = tpu.num_expected_global_devices()
for ordinal, value in results.items():
expected = ordinal + 1 if ordinal < world_size // 2 else ordinal + 1 - world_size // 2
np.testing.assert_array_equal(value, [expected])


if __name__ == '__main__':
absltest.main()
41 changes: 0 additions & 41 deletions test/test_torch_distributed_xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,47 +166,6 @@ def test_reduce_scatter_coalesced(self):
# purge all computations attached the device.
torch_xla.sync()

@patch_world(0, 6)
def test_send(self):
device = torch_xla.device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
input_list = [tensor]

with mock.patch.object(
torch_xla.distributed.xla_backend.ProcessGroupXla,
'make_send_channel_id',
new=lambda self, dst_rank, tag: dst_rank * 2):
dist.send(tensor, 1)

send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
hlo_matches(hlo, send_pattern)
hlo_matches(hlo, senddone_pattern)

# Don't try to run Send on CPU because it's not implemented
torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))

@patch_world(0, 6)
def test_recv(self):
device = torch_xla.device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

with mock.patch.object(
torch_xla.distributed.xla_backend.ProcessGroupXla,
'make_recv_channel_id',
new=lambda self, src_rank, tag: src_rank * 3):
dist.recv(tensor, 1)

recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
hlo_matches(hlo, recv_pattern)
hlo_matches(hlo, recvdone_pattern)

# Don't try to run Recv on CPU because it's not implemented
torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))

@patch_world(rank=0, size=12)
def test_new_group_no_ranks(self):
with new_group_barrier_disabled():
Expand Down
3 changes: 0 additions & 3 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor,
pairs: List[List[int]]) -> torch.Tensor:
"""Performs a XLA `CollectivePermute()` operation on the input tensor.

WARNING: This function is not very reliable, may produce wrong results under
certain inputs. Use it at your own risk.

Comment on lines -751 to -753
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in #8815 there's no context for this ancient warning. Given the age, lack of details, and lack of any other reported bugs I think it's best to remove it. If we get a specific bug report then we can act on that.

See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute

Args:
Expand Down
39 changes: 17 additions & 22 deletions torch_xla/distributed/xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,41 +235,36 @@ def gather(self, *args):
def scatter(self, *args):
raise NotImplementedError

# Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
# the maker with their specific one. See unit test in
# test/test_torch_distributed_xla_backend.py for an example.
def make_send_channel_id(self, dst_rank, tag):
raise NotImplementedError

# Call site e.g.
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877
def send(self, tensors, dst_rank, tag=0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're warning to use collective_permute, but it still ends up using a collective permute, should the warning itself be clearer that this is happening under the hood?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could word this better. The real advice is to restructure your code so that each process calls collective_permute with all of the send-recv pairs

logging.warning(
"Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute()."
)
Comment on lines +241 to +243
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it happen to print it everytime we trace?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably. I'm not sure how to only make it print once -- will look into it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked around, and couldn't find a in built way to do this through logging.warning. Given this is at warning level and can be filtered out, is it worth to seek a solution?

results = []
for t in tensors:
channel_id = self.make_send_channel_id(dst_rank, tag)
# The input will be returned as result.
input_as_result = xm.send(t, channel_id)
# Make the sent tensor depend on the token, such that the `send`
# op can actually be built into the computation graph.
result_t = xm.collective_permute(
t, pairs=[[xr.global_ordinal(), dst_rank]])
# Every process must have the same IR, otherwise they deadlock. But in
# the receiving process the provided tensor receives the result, while
# in the sending process it is unchanged. The solution used here is to
# have every process copy a linear combination of the two tensors, but
# send/recv use different coefficients to achieve different outcomes.
Comment on lines +250 to +252
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This took a couple reads until I understood what was going on here. My understanding is that by having both result_t * X + t * Y you are having both operation IRs be the same as X and Y are constants. That way when the IRs are compared they will be equivalent.

If this understanding is correct, could you add a little bit more here to make it more apparent?

with torch.no_grad():
t.copy_(input_as_result)
results.append(input_as_result)
t.copy_(result_t * 0.0 + t * 1.0)
results.append(result_t)
return _ret_work(results)

# Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
# the maker with their specific one. See unit test in
# test/test_torch_distributed_xla_backend.py for an example.
def make_recv_channel_id(self, src_rank, tag):
raise NotImplementedError

# Call site e.g.
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913
def recv(self, out_tensors, src_rank, tag=0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the warning on the recv end too, so each host has it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not assume someone reading "recv" will have read the documentation for "send". I think we should add documentation here. I would then add a note specific about what the IR expectation will be for "send" and "recv" on each of their comments.

results = []
for ot in out_tensors:
channel_id = self.make_recv_channel_id(src_rank, tag)
result = xm.recv(ot, channel_id)
results.append(result)
result_t = xm.collective_permute(
ot, pairs=[[src_rank, xr.global_ordinal()]])
with torch.no_grad():
ot.copy_(result_t * 1.0 + ot * 0.0)
results.append(result_t)
return _ret_work(results)

def recv_anysource(self, *args):
Expand Down