[RFC] MPMD+SPMD Pipeline Parallelism #9019
Comments
Thank you for sharing this RFC! We are reviewing it. cc @lsy323 @tengyifei @haifeng-jin @bhavya01 @pgmoka @GleasonK
We had minimally prototyped it with XLA + SPMD - mimicking communication across heterogeneous SPMD worlds, minus the native distributed pipelining APIs:
The latter 2 were Neuron-specific, but I'll adapt and share them as part of the RFC, as suggested by Siyuan.
Thanks for the RFC! This seems like quite a lot of work, and many details need to be filled in, e.g., how to support profiling. Got a few questions:
At the top level, I don't quite understand the benefit of the SPMD+MPMD solution in the RFC compared with an MPMD-only solution. I think leaving everything to the XLA compiler should be much easier and more efficient for overlapping communication overhead. Can you elaborate more on the motivation? Thanks!
Thanks @zpcore!
In the RFC above, we started by focusing on local SPMD worlds that have the same number of participants and localized SPMD meshes. Communication occurs across SPMD local ranks: for example, with 2 stages and 8 devices each, we'd have P2P communication as [[0, 8], [1, 9], ..., [7, 15]], transferring sharded data across stages (with the receiving rank using device data IR placeholders). We're also careful not to make decisions that would prevent a future extension to heterogeneous local SPMD worlds (different meshes or participant counts), which would require synchronization before or after stages. On the top-level question, capturing all the different variants for the wider picture:
- We don't limit opportunities to optimize communication overlap in the XLA compiler, as we can leverage insights from both model strategies. This will be addressed in OpenXLA RFC follow-ups, with a clear separation between SPMD and MPMD communication collective operations, which can conceptually help with execution sequence dependencies, communication overlapping, and optimal computation scheduling. That work can be adapted when supporting MPMD+SPMD strategies even without pipeline parallelism (e.g., heterogeneous SPMD computation programs).
- We can work towards feature parity between the two approaches, as there's substantial implementation overlap - but that's not in scope here.
- The MPMD+SPMD strategy combines advantages from both models. Recent studies with Pathways and JaxPP [4 and 3, above] have extended JAX with MPMD support, though they subsequently focused on other aspects of the architecture stack.

I tried to keep the motivation concise, but I can move some of these points into the RFC if they provide a better comparative picture.
It would be great to discuss the runtime requirements of this work in the context of a multi-host setting.
@rpsilva-aws High level questions:
🚀 Feature
We propose an accelerator-agnostic, hybrid Single-Program Multiple-Data (SPMD)/Multiple-Program Multiple-Data (MPMD) Pipeline Parallelism implementation in PyTorch XLA. The key objectives are:
In the image below, we see the MPMD+SPMD execution strategy on the RHS. For each pipeline stage/rank, we still require the XLA compiler to partition the single-device program with the needed CC ops, based on the GSPMD sharding annotations. To enable heterogeneous use cases and more complex scheduling strategies, we allow users to split their global devices (and pipeline stages) into separate SPMD single-device programs, invoking the MPMD execution strategy across the (local) SPMD worlds.
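As a rough illustration of the intended usage, the sketch below builds one stage-local SPMD mesh per pipeline stage. The `build_stage_mesh` helper, the 2-stage / 8-devices-per-stage layout, and passing a subset of the global device IDs to `Mesh` are all assumptions of this RFC (localized SPMD), not current PyTorch/XLA behavior, and the final API may differ.

```python
import numpy as np

import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

NUM_STAGES = 2          # assumption: 2 pipeline stages
DEVICES_PER_STAGE = 8   # assumption: homogeneous stages, 8 devices each

xr.use_spmd()  # enable the SPMD execution mode


def build_stage_mesh(stage_id: int) -> xs.Mesh:
    """Hypothetical helper: a stage-local (2, 4) mesh over this stage's devices."""
    # Global device IDs owned by this stage, e.g. stage 0 -> [0..7], stage 1 -> [8..15].
    device_ids = np.arange(stage_id * DEVICES_PER_STAGE,
                           (stage_id + 1) * DEVICES_PER_STAGE)
    # Within a stage, the XLA SPMD partitioner still emits the intra-stage CC ops
    # from the GSPMD annotations placed on this mesh.
    return xs.Mesh(device_ids, (2, 4), ("data", "model"))
```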
Motivation
Existing SPMD-only Pipeline Parallelism implementations (mainly JAX’s AXLearn [1] and Praxis [2]) have fundamental limitations. They require redundant computation across all ranks to maintain SPMD semantics, making pipeline parallelism inefficient and unable to easily support advanced scheduling strategies besides GPipe. As model sizes increase, memory and communication bandwidth become critical bottlenecks. Our hybrid approach, tied with PyTorch’s native distributed pipelining APIs, aims to substantially improve performance, drawing from recent research results [3].
There have been MPMD-only RFCs to introduce Pipeline Parallelism for PyTorch/XLA and upstream it to PyTorch (#6347). That effort should become relatively easier to achieve as we close some of the overlapping gaps with XLA. It is a subset of this RFC/project, and the design/implementation should not build solutions that are incompatible with running without SPMD. End-to-end testing and validation of that mode is out of scope for now, but will likely be included later, as we don’t expect significant nuances if a user were to use a single participating device in their (local) SPMD worlds. Similarly, there is no evident risk in enabling SPMD-only PP applications.
Pitch
Native PyTorch
We extend PyTorch's distributed pipelining (formerly PiPPy) solutions, leveraging FX graphs (https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule) for model partitioning, graph specification, and execution. We choose to build on top of the existing library due to:
The goal is to contribute to PyTorch and close the parity gap so the library works with XLA devices. The partitioned subgraphs are traced and compiled into executable programs that run on XLA devices. We include the needed XLA- and SPMD-specific invariants in the staging of the P2P communication across ranks, factoring in the device mesh (e.g. sharding specifications).
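For context, a rough sketch of the existing `torch.distributed.pipelining` front-end this builds on is shown below. The toy model, split point, and microbatch count are illustrative; targeting `xm.xla_device()` in `build_stage` is precisely the gap this RFC proposes to close, so treat that part as aspirational rather than working code today.

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import SplitPoint, pipeline, ScheduleGPipe

import torch_xla.core.xla_model as xm


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(1024, 1024)
        self.layer1 = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.layer1(self.layer0(x))


model = ToyModel()
example_microbatch = torch.randn(2, 1024)

# Trace the model into per-stage FX GraphModules.
pipe = pipeline(
    model,
    mb_args=(example_microbatch,),
    split_spec={"layer1": SplitPoint.BEGINNING},
)

# Each pipeline rank materializes only its own stage and compiles it for XLA
# (aspirational: XLA-device support is what this RFC proposes).
stage = pipe.build_stage(stage_index=0, device=xm.xla_device())
schedule = ScheduleGPipe(stage, n_microbatches=4)
```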
SPMD
We propose a solution that combines the benefits of both SPMD and MPMD paradigms, providing a foundation for future PP work:
This will substantially improve training efficiency for large language models while maintaining compatibility with the broader PyTorch ecosystem.
Other parallelism techniques
Any other parallelism technique with SPMD should work seamlessly with Pipeline Parallelism. The idea is that each PP rank manages an independent SPMD world; besides P2P communication, all participating (local) devices execute the same program, and CC ops are still generated by XLA. Users can define and map the optimal training configuration, e.g. TP over high-bandwidth dimensions and DP/PP over low-bandwidth dimensions.
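As a minimal sketch of how this composes within a single stage (assuming a stage-local 2x4 mesh as in the earlier sketch, with illustrative axis names), tensor parallelism is expressed through ordinary GSPMD sharding annotations, while pipeline parallelism is handled across stages:

```python
import numpy as np
import torch

import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

# Stage-local 2x4 mesh over this stage's 8 devices; "model" would map to the
# high-bandwidth dimension and "data" to the lower-bandwidth one.
mesh = xs.Mesh(np.arange(8), (2, 4), ("data", "model"))

# Shard a linear layer's weight over the "model" axis (TP) and replicate the
# other dimension; XLA's partitioner still generates the intra-stage CC ops.
weight = torch.randn(4096, 4096, device=xm.xla_device())
xs.mark_sharding(weight, mesh, ("model", None))
```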
Localized SPMD
PyTorch/XLA supports localized SPMD execution, allowing individual hosts to run SPMD programs on their local devices independently. This feature provides greater flexibility in multi-host environments, especially for MPMD+SPMD workloads. We implement localized SPMD by decoupling logical device indices from physical device IDs and supporting implicit localization within the XLA graph executor. There is already some preliminary work on this from Siyuan, though for a different use case (#8810).
This implementation allows SPMD programs to reason about their execution using a consistent logical device indexing scheme, while the runtime handles the mapping to physical devices. This should honor the distributed process groups and the OpSharding of all live tensors captured by the XLA graph tracing. We ensure that the implementation can serve heterogeneous applications as well as the intended pipeline parallelism scope. This needs to come with the capability to define submeshes across all participating devices.
Furthermore, we start by requiring the same number of (local participating) devices for each program across all SPMD worlds. Hence, we maintain a direct mapping for each piece of sharded data between SPMD worlds, and there is no need for CC ops to collect/reduce data that is communicated across pipeline stages. The design/implementation should not adopt solutions that would complicate the extension to heterogeneous (local) SPMD worlds, namely gathering/reducing/resharding tensors preceding/following P2P communications. A sketch of the resulting rank pairing follows.
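The small sketch below shows the rank pairing this implies for homogeneous stages (the helper name and the 8-devices-per-stage layout are illustrative): local rank i of one stage exchanges activations/gradients only with local rank i of the next stage, so no gather/reduce is needed around the P2P transfer.

```python
DEVICES_PER_STAGE = 8  # illustrative


def p2p_peers(stage: int, next_stage: int, n: int = DEVICES_PER_STAGE):
    """Return (send_rank, recv_rank) pairs between two adjacent stages."""
    return [(stage * n + i, next_stage * n + i) for i in range(n)]


# With 2 stages of 8 devices each: [(0, 8), (1, 9), ..., (7, 15)]
print(p2p_peers(0, 1))
```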
Process groups
Currently, PyTorch/XLA does not support the XLA backend together with SPMD, since there is a single replica and CC ops are generated by XLA’s partitioner. We are working on relaxing this constraint in order to simplify the interfaces with SPMD+MPMD, enabling both the native XLA backend for torch.distributed (P2P communication) and other distributed APIs (DTensor). The former dictates the participating ranks for any given submesh, while the latter serves mainly to simplify and improve the user interface, allowing users to define sub-meshes while abstracting the MPMD+SPMD semantics.
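An aspirational sketch of what the relaxed setup could look like is below. Combining the `xla` torch.distributed backend with SPMD is precisely what is not supported today, so the group layout (one small group per cross-stage pair, mirroring the pairing above) is illustrative only.

```python
import torch.distributed as dist

import torch_xla.distributed.xla_backend  # noqa: F401  (registers the "xla" backend)

dist.init_process_group("xla", init_method="xla://")

# One small process group per cross-stage pair, e.g. {0, 8}, {1, 9}, ...
# (every rank must participate in each new_group call).
pairs = [(i, i + 8) for i in range(8)]
pair_groups = [dist.new_group(ranks=list(p)) for p in pairs]
# Each rank would later issue send/recv on the group that contains it.
```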
Model initialization
We want to ensure that model initialization on PyTorch/XLA does not materialize the entire model on the device for each stage. To avoid this, we want to extend the existing meta device context to only capture the metadata, ensuring that it can be used interchangeably with the XLA device. Once a model is traced, we can then send the sharded data to the device. Another consideration is RNG seeding: each rank should generate its own unique RNG seed.
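A minimal sketch of the intended initialization flow, assuming a toy two-layer model and a hypothetical `pipeline_rank` variable: build the full model on the meta device (shapes and dtypes only), seed the RNG per rank, and materialize only this stage's submodule on the XLA device.

```python
import torch
import torch.nn as nn

import torch_xla.core.xla_model as xm

pipeline_rank = 0  # hypothetical: this process's pipeline stage/rank

with torch.device("meta"):
    # Metadata only; no parameter storage is allocated on the accelerator.
    full_model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

torch.manual_seed(1234 + pipeline_rank)  # each rank gets its own RNG seed

# Materialize only this stage's parameters on the XLA device; real code would
# then initialize or load the (sharded) weights before training.
stage_module = full_model[pipeline_rank].to_empty(device=xm.xla_device())
```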
XLA
Currently, XLA has no notion of MPMD+SPMD and hence requires a single replica group with N participants for the entire SPMD world. It requires the ranks to use 0-based ordinal indexing and the sharding annotations to describe the entire SPMD world for the SPMD partitioner. This requires an orthogonal RFC and will be evaluated in parallel, namely for PJRT client logical ID mappings, HLO specification proposals, and threaded scheduling for P2P communications.
XLA + SPMD:
[1] https://github.com/apple/axlearn/blob/main/axlearn/common/pipeline.py
[2] https://github.com/google/praxis/blob/main/praxis/layers/pipeline.py
[3] https://arxiv.org/abs/2412.14374
[4] https://arxiv.org/abs/2203.12533
cc: @tengyifei @wconstab @kwen2501