Description
🚀 Feature
Improve coverage of the PyTorch collective operations so that native distributed code is more likely to work without any user modification. Ideally, collective ops would work both in LazyTensor mode and in compiled mode, but many do not have an upstream Dynamo path.
The following are in scope.
Collectives that operate on objects instead of tensors are out of scope.
Pitch
All Gather Into Tensor
Gathers tensors into a single output tensor. This is like the already-implemented all_gather, but it outputs a single tensor instead of a list of tensors. It appears in FSDP (example) and other places.
The current implementation wraps all_gather, and it fails in some cases because it was not set up to support both stacking and concatenation of the input tensors. This zip, which is meant for a list of tensors, happens to work in some cases by splitting the tensor along the 0th dimension.
This should be a simple refactor: isolate the common logic, but let stacking, concatenation, or copying to individual tensors happen only in the correct cases.
Done in #9332
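For reference, a minimal sketch of the two output shapes the wrapper needs to handle, written against plain torch.distributed for illustration rather than the actual PyTorch/XLA code; all_gather_into_tensor_ref is a hypothetical name.

```python
import torch
import torch.distributed as dist

def all_gather_into_tensor_ref(output: torch.Tensor, input: torch.Tensor, group=None):
    """Sketch of all_gather_into_tensor semantics built on top of all_gather."""
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(input) for _ in range(world_size)]
    dist.all_gather(shards, input, group=group)
    if output.shape == (world_size, *input.shape):
        # Stacking case: the output adds a new leading dimension.
        torch.stack(shards, dim=0, out=output)
    else:
        # Concatenation case: shards are concatenated along dim 0.
        torch.cat(shards, dim=0, out=output)
    return output
```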
Broadcast
This has already been solved in PR 7956 and PyTorch PR 135171, but those PRs were not merged and should be revived. Broadcast is used in DDP.
Scatter
Scatters a list of tensors across processes. It is used by sharding code. Scatter does not have a corresponding XLA op, but we could implement it similarly to how broadcast is implemented: perform a reduce_scatter after multiplying the tensors on all non-source devices by 0.
Done in #9365
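A minimal sketch of that approach, again using plain torch.distributed for illustration; scatter_via_reduce_scatter is a hypothetical name.

```python
import torch
import torch.distributed as dist

def scatter_via_reduce_scatter(output: torch.Tensor, scatter_list, src: int, group=None):
    """Sketch: scatter expressed as a SUM reduce_scatter with zeroed non-source inputs."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    if rank == src:
        inputs = list(scatter_list)
    else:
        # Non-source ranks contribute zeros, so the SUM reduction leaves each
        # rank with exactly the shard the source intended for it.
        inputs = [torch.zeros_like(output) for _ in range(world_size)]
    dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM, group=group)
    return output
```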
Reduce Scatter
Reduces a list of tensors, then scatters the result across processes. This is like the already-implemented reduce_scatter_tensor, but it acts on a list of tensors instead of a single tensor. It is not used in the torch.distributed code but may be necessary to implement scatter.
Reduce Scatter works in LazyTensor mode, but it fails when compiled because there is no Dynamo mapping. There is not even a reduce_scatter functional collective, but reduce_scatter_tensor_coalesced might be usable. Once implemented, we would bind it in pt/xla (the analogous binding for reduce_scatter_tensor is here). After implementing the binding, the logic can be shared with the existing reduce_scatter_tensor, and the underlying XLA op can accept a tuple of arrays, so the rest should be straightforward.
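To illustrate the intended semantics (not the proposed lowering), the list form can be expressed in terms of the already-implemented reduce_scatter_tensor by concatenating the per-rank inputs; reduce_scatter_ref is a hypothetical name.

```python
import torch
import torch.distributed as dist

def reduce_scatter_ref(output: torch.Tensor, input_list, op=dist.ReduceOp.SUM, group=None):
    """Sketch: rank i ends up with the reduction across ranks of everyone's input_list[i]."""
    flat_input = torch.cat([t.reshape(-1) for t in input_list])
    flat_output = output.new_empty(output.numel())
    # Chunk i of flat_input is input_list[i], so reduce_scatter_tensor hands
    # rank i the reduced chunk i.
    dist.reduce_scatter_tensor(flat_output, flat_input, op=op, group=group)
    output.copy_(flat_output.view_as(output))
    return output
```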
Send and Recv
Send and receive are implemented but do not work. As seen in #8074, the XLA ops are missing "_xla_send_recv_source_target_pairs". Those fields get set here. XLA's Send and Recv aren't meant to be called directly. Instead the user is expected to use CollectivePermute and specify every source-target pair. If there are no cycles then the HLO is decomposed into Send and Recv ops.
We can replace send/recv with calls to xm.collective_permute.
Done in #9373
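A minimal sketch of what a point-to-point send could look like on top of collective_permute; send_via_collective_permute is a hypothetical name, and the real integration would live in the torch_xla backend rather than user code.

```python
import torch
import torch_xla.core.xla_model as xm

def send_via_collective_permute(tensor: torch.Tensor, src: int, dst: int) -> torch.Tensor:
    # A single (source, target) pair moves the tensor from src to dst; ranks
    # that are not a target in any pair receive zeros.
    return xm.collective_permute(tensor, pairs=[[src, dst]])
```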
Gather
This is not used in the torch.distributed code but was requested in #9069. It could be implemented using all_gather and only keeping the result on the dst rank.
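A minimal sketch of that approach, written against plain torch.distributed for illustration; gather_via_all_gather is a hypothetical name.

```python
import torch
import torch.distributed as dist

def gather_via_all_gather(tensor: torch.Tensor, gather_list, dst: int, group=None):
    """Sketch: gather built on all_gather, keeping the result only on dst."""
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, group=group)
    # Every rank pays for the full all_gather, but only dst keeps the result.
    if dist.get_rank(group) == dst:
        for out, shard in zip(gather_list, gathered):
            out.copy_(shard)
```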
Reduce
This is not used in the torch.distributed code. It could be implemented using all_reduce and only keeping the result on the dst rank.
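A minimal sketch, analogous to the gather sketch above; reduce_via_all_reduce is a hypothetical name.

```python
import torch
import torch.distributed as dist

def reduce_via_all_reduce(tensor: torch.Tensor, dst: int, op=dist.ReduceOp.SUM, group=None):
    """Sketch: reduce built on all_reduce, keeping the reduced value only on dst."""
    original = tensor.clone()
    dist.all_reduce(tensor, op=op, group=group)
    if dist.get_rank(group) != dst:
        # Only the destination rank keeps the reduced value; the other ranks
        # restore their original tensors.
        tensor.copy_(original)
    return tensor
```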
All to All
This is not used in the torch.distributed code. Since all_to_all_single is implemented, we could probably stack the tensors, run all_to_all_single, and then chunk the result.
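A minimal sketch of that stacking/chunking approach, assuming all tensors in the lists share one shape; all_to_all_via_single is a hypothetical name.

```python
import torch
import torch.distributed as dist

def all_to_all_via_single(output_list, input_list, group=None):
    """Sketch: list-based all_to_all expressed via all_to_all_single."""
    # Stack the per-rank inputs into one tensor with a leading world_size dim,
    # so chunk i along dim 0 is exactly input_list[i].
    stacked_input = torch.stack(input_list, dim=0)
    stacked_output = torch.empty_like(stacked_input)
    dist.all_to_all_single(stacked_output, stacked_input, group=group)
    # Chunk the result back into the caller's output list.
    for out, chunk in zip(output_list, stacked_output.unbind(dim=0)):
        out.copy_(chunk)
    return output_list
```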