
[RFC] MPMD+SPMD Pipeline Parallelism #9019


Open
rpsilva-aws opened this issue Apr 22, 2025 · 6 comments
Labels: distributed, enhancement

Comments

@rpsilva-aws
Collaborator

rpsilva-aws commented Apr 22, 2025

🚀 Feature

We propose an accelerator-agnostic, hybrid Single-Program Multiple-Data (SPMD)/Multiple-Program Multiple-Data (MPMD) Pipeline Parallelism implementation in PyTorch/XLA. The key objectives are:

  • Enable efficient model-parallel training for large language models, retaining SPMD semantics within each pipeline stage so that users can write code as if targeting a single large device
  • Enable heterogeneous programs and advanced scheduling strategies (1F1B, Interleaved, ZBV Zero Bubble) across pipeline stages with MPMD
  • Leverage PyTorch's native distributed pipelining APIs
  • Minimize model changes for users

In the image below, the MPMD+SPMD execution strategy is shown on the right-hand side. For each pipeline stage/rank, we still rely on the XLA compiler to partition the single-device program and insert the needed CC ops based on the GSPMD sharding annotations. To enable heterogeneous use cases and more complex scheduling strategies, we allow users to split their global devices (and pipeline stages) into separate single-device SPMD programs, applying the MPMD execution strategy across the (local) SPMD worlds.

[Image: pipeline execution strategies; MPMD+SPMD shown on the right-hand side]

Motivation

Existing SPMD-only Pipeline Parallelism implementations (mainly the JAX-based AXLearn [1] and Praxis [2]) have fundamental limitations. They require redundant computation across all ranks to maintain SPMD semantics, which makes pipeline parallelism inefficient and makes it hard to support scheduling strategies beyond GPipe. As model sizes increase, memory and communication bandwidth become critical bottlenecks. Our hybrid approach, built on PyTorch's native distributed pipelining APIs, aims to substantially improve performance, drawing on recent research results [3].

Earlier MPMD-only RFCs proposed introducing Pipeline Parallelism to PyTorch/XLA and upstreaming it to PyTorch (#6347). That effort should become easier as we close some of the overlapping gaps with XLA. It is a subset of this RFC/project, and the design and implementation should avoid solutions that would not work without SPMD. End-to-end testing and validation of the MPMD-only case is out of the immediate scope but will likely be included, as we don't expect significant differences when a user has a single participating device in each (local) SPMD world. Similarly, we see no evident risk in enabling SPMD-only PP applications.

Pitch

Native PyTorch

We build on PyTorch's distributed pipelining (formerly PiPPy), which uses FX graphs (https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule) for model partitioning, graph specification, and execution. We chose to build on the existing library because of:

  • Relatively well-established infrastructure in PyTorch, leveraging FX graphs for model partitioning
  • Support for both graph-based splitting via torch.export and manual model splitting
  • A growing number of existing schedule implementations
  • Declarative schedule format
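A declarative schedule is easiest to picture as an ordered list of actions per stage. As a minimal sketch (this is not the action format or API of `torch.distributed.pipelining`), the hypothetical `one_f_one_b` helper lists, for one pipeline stage, the forward (`F<i>`) and backward (`B<i>`) actions of a standard 1F1B schedule over `num_microbatches` microbatches:

```python
from typing import List


def one_f_one_b(rank: int, num_stages: int, num_microbatches: int) -> List[str]:
    """Ordered 1F1B actions for one stage: "F<i>" = forward, "B<i>" = backward."""
    actions: List[str] = []
    fwd = bwd = 0
    # Warm-up: earlier stages run extra forwards so the pipeline fills up.
    num_warmup = min(num_stages - rank - 1, num_microbatches)
    for _ in range(num_warmup):
        actions.append(f"F{fwd}")
        fwd += 1
    # Steady state: alternate one forward with one backward.
    while fwd < num_microbatches:
        actions.append(f"F{fwd}")
        fwd += 1
        actions.append(f"B{bwd}")
        bwd += 1
    # Cool-down: drain the remaining backwards.
    while bwd < num_microbatches:
        actions.append(f"B{bwd}")
        bwd += 1
    return actions


print(one_f_one_b(rank=1, num_stages=4, num_microbatches=4))
# ['F0', 'F1', 'F2', 'B0', 'F3', 'B1', 'B2', 'B3']
```

In steady state each stage alternates one forward with one backward, so a stage holds at most `num_stages - rank` microbatches of activations at a time, instead of all of them as in GPipe's all-forwards-then-all-backwards order.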

The goal is to contribute to PyTorch and close the parity gap for XLA devices. The partitioned subgraphs are traced and compiled into executable programs that run on XLA devices. We include the needed XLA- and SPMD-specific invariants when staging the P2P communication across ranks, factoring in the device mesh (e.g., sharding specifications).

SPMD

We propose a solution that combines the benefits of both SPMD and MPMD paradigms, providing a foundation for future PP work:

  • Maintain SPMD semantics (TP, FSDP, DP, CP) for each MPMD execution (PP rank)
  • Enable customers to seamlessly use this new paradigm with existing interfaces and scheduling variants
  • Ensure compatibility with existing interfaces
  • Design for extensibility across PyTorch and JAX ecosystems

This will substantially improve training efficiency for large language models while maintaining compatibility with the broader PyTorch ecosystem.

Other parallelism techniques

Any other parallelism technique with SPMD should work seamlessly with Pipeline Parallelism. Each PP rank manages an independent SPMD world: apart from P2P communication, all participating (local) devices execute the same program, and CC ops are still generated by XLA. Users can define and map the optimal training configuration, e.g., TP mapped over high-bandwidth dimensions and DP/PP over low-bandwidth dimensions.
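A minimal sketch of such a mapping, in plain NumPy and assuming that consecutive device IDs share the fastest links (for example, devices on the same host): global device IDs are arranged as a `("pp", "data", "model")` array with the TP axis varying fastest, and each slice along `pp` is the local SPMD world of one PP rank. The helper names are hypothetical; in PyTorch/XLA the resulting ID arrays are what one would hand to a mesh such as `xs.Mesh`.

```python
import numpy as np


def build_global_mesh(num_stages: int, data: int, model: int) -> np.ndarray:
    """Arrange device IDs as a ("pp", "data", "model") array.

    Assumption: consecutive device IDs share the fastest links, so keeping
    "model" (TP) as the fastest-varying axis keeps each TP group on the
    high-bandwidth interconnect, while DP and PP span slower links.
    """
    return np.arange(num_stages * data * model).reshape(num_stages, data, model)


def stage_world(mesh: np.ndarray, pp_rank: int) -> np.ndarray:
    """The (data, model) submesh forming one PP rank's local SPMD world."""
    return mesh[pp_rank]


def tp_groups(world: np.ndarray) -> list:
    # Each row of the (data, model) submesh is one TP group.
    return world.tolist()


def dp_groups(world: np.ndarray) -> list:
    # Each column of the (data, model) submesh is one DP group.
    return world.T.tolist()


mesh = build_global_mesh(num_stages=2, data=2, model=4)
world = stage_world(mesh, pp_rank=1)
print(tp_groups(world))  # [[8, 9, 10, 11], [12, 13, 14, 15]]
print(dp_groups(world))  # [[8, 12], [9, 13], [10, 14], [11, 15]]
```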

Localized SPMD

PyTorch/XLA supports localized SPMD execution, allowing individual hosts to run SPMD programs on their local devices independently. This feature provides greater flexibility in multi-host environments, especially for MPMD+SPMD workloads. We implement localized SPMD by decoupling logical device indices from physical device IDs and supporting implicit localization within the XLA graph executor. There is already some preliminary work on this from Siyuan, though for a different use case (#8810).

  • Logical Ordinal Abstraction: Within each localized SPMD program, devices are addressed using logical ordinals. Under the hood, the device assignment maps them to the physical devices managed by the PJRT client.
  • Implicit Device Ordinal Mapping: Within each graph executor and device instance, the participating devices are inferred from the OpSharding. Invariants are added to ensure that, for each SPMD world (replica group), the participating devices are uniform across all XLA tensors.
  • Compilation Configuration: During compilation, the XLA compiler is configured to target only the local devices accessible to the current process (separate RFC).
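The two ordinal bullets above can be modeled in plain Python. This is a hypothetical sketch of the bookkeeping, not PyTorch/XLA or PJRT code: `LocalDeviceAssignment` maps logical ordinals to physical device IDs, and `infer_local_world` derives the SPMD world from the devices referenced by each live tensor's sharding (represented here as a list of IDs standing in for an OpSharding tile assignment), enforcing the uniformity invariant.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple


@dataclass(frozen=True)
class LocalDeviceAssignment:
    """Maps logical ordinals 0..N-1 of one SPMD world to physical device IDs."""

    physical_ids: Tuple[int, ...]  # index i holds the physical ID of logical ordinal i

    def to_physical(self, logical: int) -> int:
        return self.physical_ids[logical]

    def to_logical(self, physical: int) -> int:
        return self.physical_ids.index(physical)


def infer_local_world(tensor_devices: Dict[str, Sequence[int]]) -> LocalDeviceAssignment:
    """Infer the participating devices from every live tensor's sharding.

    `tensor_devices` maps a tensor name to the physical device IDs in its
    sharding (a stand-in for an OpSharding tile assignment). All tensors must
    agree on one device set, mirroring the per-world uniformity invariant.
    """
    device_sets = {frozenset(ids) for ids in tensor_devices.values()}
    if len(device_sets) != 1:
        raise ValueError(f"expected one SPMD world, found {len(device_sets)} device sets")
    # Logical ordinals follow increasing physical ID.
    return LocalDeviceAssignment(tuple(sorted(device_sets.pop())))


world = infer_local_world({"weight": [9, 8, 11, 10], "activation": [8, 9, 10, 11]})
print(world.to_logical(10))  # 2
print(world.to_physical(0))  # 8
```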

This implementation allows SPMD programs to reason about their execution using a consistent logical device indexing scheme, while the runtime handles the mapping to physical devices. It should honor the distributed process groups and the OpSharding of all live tensors captured by the XLA graph tracing. We ensure that the implementation can serve heterogeneous applications as well as the intended pipeline parallelism scope. This requires the ability to define submeshes across all participating devices.

Initially, we require every SPMD world to have the same number of (local participating) devices. Hence, each shard maps directly to a shard in the next SPMD world, and no CC ops are needed to gather or reduce the data communicated across pipeline stages. The design and implementation should avoid choices that would complicate a later extension to heterogeneous (local) SPMD worlds, which requires gathering, reducing, or resharding tensors before or after P2P communication.
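Because every SPMD world has the same device count, the P2P pattern between adjacent stages is a fixed one-to-one pairing of local ordinals. The sketch below assumes, hypothetically, that stage `s` owns the contiguous device IDs `[s * n, (s + 1) * n)`; under that layout, 2 stages of 8 devices each give the pairs [[0, 8], [1, 9], ..., [7, 15]] discussed in the comments below.

```python
from typing import List, Tuple


def p2p_pairs(src_stage: int, dst_stage: int, devices_per_stage: int) -> List[Tuple[int, int]]:
    """Pair each shard of `src_stage` with the same shard of `dst_stage`.

    Assumes every SPMD world has `devices_per_stage` devices and that stage s
    owns device IDs [s * n, (s + 1) * n). Since the meshes match, local
    ordinal i sends straight to local ordinal i, with no gather, reduce, or
    reshard around the transfer.
    """
    src_base = src_stage * devices_per_stage
    dst_base = dst_stage * devices_per_stage
    return [(src_base + i, dst_base + i) for i in range(devices_per_stage)]


print(p2p_pairs(0, 1, 4))  # [(0, 4), (1, 5), (2, 6), (3, 7)]
```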

Process groups

Currently, PyTorch/XLA does not support the XLA backend for torch.distributed with SPMD, since there is a single replica and CC ops are generated by XLA's SPMD partitioner. We are working on relaxing this constraint to simplify the MPMD+SPMD interfaces, enabling both the native XLA backend for torch.distributed (P2P communication) and other distributed APIs (DTensor). The former dictates the participating ranks for any given submesh; the latter mainly simplifies the user interface, allowing users to define submeshes and abstracting away the MPMD+SPMD semantics.

Model initialization

We want to ensure that model initialization on PyTorch/XLA does not materialize the entire model on the device for each stage. To avoid this, we want to extend the existing meta device context to capture only the metadata, so that it can be used interchangeably with the XLA device. Once a model is traced, we can then send the sharded data to the device. We must also ensure that RNG seeds are generated correctly, since each rank should have its own unique RNG seed.
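A minimal sketch of the meta-device flow using standard PyTorch (2.0+) APIs: the module is built under `torch.device("meta")`, storage is allocated on the target device with `Module.to_empty`, and parameters are then initialized with a rank-dependent seed. The helper name and the `base_seed + rank` seeding scheme are assumptions for illustration. Doing this against an XLA device, and sending only the sharded data, is what the RFC proposes to support, so the example uses `"cpu"`.

```python
import torch
from torch import nn


def init_stage_module(build_fn, device, base_seed: int, rank: int) -> nn.Module:
    """Build a stage module without materializing its weights up front."""
    # Under the meta device only shapes and dtypes are recorded; no storage is allocated.
    with torch.device("meta"):
        module = build_fn()
    # Allocate (uninitialized) storage directly on the target device.
    module = module.to_empty(device=device)
    # Assumed seeding scheme: offset a shared base seed by the rank so each
    # rank draws a different but reproducible initialization.
    torch.manual_seed(base_seed + rank)
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()
    return module


stage = init_stage_module(lambda: nn.Linear(16, 8), "cpu", base_seed=42, rank=1)
print(stage.weight.shape)  # torch.Size([8, 16])
```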

XLA

Currently, XLA has no notion of MPMD+SPMD and hence requires a single replica group with N participants for the entire SPMD world. It requires device ordinals to be 0-based and the sharding annotations to describe the entire SPMD world for the SPMD partitioner. This needs a separate, orthogonal RFC that will be evaluated in parallel, covering PJRT client logical ID mappings, HLO specification proposals, and threaded scheduling for P2P communications.

XLA + SPMD:

import os
import sys
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.optim as optim

+ import torch_xla.core.xla_model as xm
+ import torch_xla.runtime as xr
+ import torch_xla.distributed.spmd as xs
+ import torch_xla.distributed.xla_backend
+ import torch_xla.distributed.parallel_loader as pl
+ from torch.distributed.device_mesh import init_device_mesh

import args_parse
import torch.distributed as dist
from pippy import pipeline, SplitPoint, ScheduleGPipe, PipelineStage

MODEL_OPTS = {
    '--input_dim': {
        'type': int,
        'default': 16834,
    },
    '--train_dataset_len': {
        'type': int,
        'default': 1024 * 8,
    },
    '--pipeline_chunks': {
        'type': int,
        'default': 4,
    }
}

FLAGS = {}

+ xr.use_spmd()

class SimpleLinear(nn.Module):
  NUM_CLASSES = 3

  def __init__(self):
    super().__init__()
    # Instead of Sequential, define layers separately for easier split points
    self.layer0 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2)
    self.relu = nn.ReLU()
    self.layer1 = nn.Linear(FLAGS.input_dim // 2, 3)
    self.layer2 = nn.Linear(3, self.NUM_CLASSES)

  def forward(self, x):
    x = self.layer0(x)
    x = self.relu(x)
    x = self.layer1(x)
    x = self.layer2(x)
    return x


def train():
+ # Torchrun is needed for Pipeline Parallelism by default. Generally, we
+ # don't need it for SPMD, and we could rely on `process_index` and
+ # `process_count` from the PjRT runtime. However, it would
+ # be needed if we have multiple SPMD worlds within the same physical machine.
+ # Hence, we retain the requirement, and can relax it later on.
+ rank = int(os.environ["RANK"])  # or xr.process_index()
+ world_size = int(os.environ["WORLD_SIZE"])  # or xr.process_count()
- rank = int(os.environ["RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
  chunks = FLAGS.pipeline_chunks
 
- if torch.cuda.is_available():
-     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
- else:
-     device = torch.device("cpu")
+ # Use XLA device
+ device = xm.xla_device()

  print(f"Rank {rank} using device {device}")

+ # (Preferred) Leverage the DTensor/DeviceMesh variants for a more seamless
+ # user interface with submeshes. The "pp" dimension has one entry per
+ # pipeline stage (i.e. per rank), not per microbatch.
+ # -----
+ num_local_devices = xr.addressable_runtime_device_count()
+ global_mesh = init_device_mesh("xla", (world_size, num_local_devices, 1),
+                                mesh_dim_names=("pp", "data", "model"))
+ local_mesh = global_mesh["data", "model"]
+ # -----
+ # Alternatively, with the PyTorch/XLA SPMD Mesh:
+ # -----
+ num_devices = xr.global_runtime_device_count()
+ num_local_devices = xr.addressable_runtime_device_count()
+
+ # Global submesh
+ global_mesh_shape = (world_size, num_local_devices, 1)
+ global_mesh = xs.Mesh(np.arange(num_devices), global_mesh_shape, ("pp", "data", "model"))
+
+ # Local submesh: the contiguous block of device IDs owned by this rank
+ device_id_start = rank * num_local_devices
+ local_device_ids = np.arange(device_id_start, device_id_start + num_local_devices)
+ local_mesh_shape = global_mesh_shape[1:]
+ local_mesh = xs.Mesh(local_device_ids, local_mesh_shape, ("data", "model"))
+ # -----

  # Initialize process group
- dist.init_process_group(rank=rank, world_size=world_size)
+ dist.init_process_group(
+     backend="xla",
+     init_method="xla://",
+     rank=rank,
+     world_size=world_size
+ )
+
  
  torch.manual_seed(42)
  model = SimpleLinear().to(device)
  
+ # Shard the model weights as needed:
+ # parallelize_model(model, local_mesh)
  
  # Define split points for pipeline parallelism
  split_spec = {
    "layer0": SplitPoint.END,
    "layer1": SplitPoint.END,
  }
  
  # Create a sample input for the pipeline
  batch_size = FLAGS.batch_size
  example_input = torch.randn(batch_size, FLAGS.input_dim, device=device)
  
  # Create the pipeline and respective stage for the rank.
  pipe = pipeline(model, chunks, example_args=(example_input,), split_spec=split_spec)
  stage = PipelineStage(pipe, rank, device)
  schedule = ScheduleGPipe(stage, chunks)
  
  # Training loop
  losses = []
  loss_fn = nn.CrossEntropyLoss()
  optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)
  
  for epoch in range(FLAGS.num_epochs):
    for step, (data, target) in enumerate(data_generator()):
      if rank == 0:
+       xs.mark_sharding(data, local_mesh, ('data', 'model'))
+       # or distribute_tensor(data, local_mesh, [Shard(0), Shard(1)])
        schedule.step(data)
        optimizer.zero_grad()
      else:
        output = schedule.step()
        
        # Only the last rank computes loss and does backward
        if rank == world_size - 1:
          loss = loss_fn(output, target)
          losses.append(loss.clone().detach())
          loss.backward()
          optimizer.step()
          
          if step % FLAGS.log_steps == 0:
            print(f"Epoch {epoch} step {step} loss {loss}")
+      xm.mark_step()

def main():
  default_config = {
      'batch_size': 128,
      'num_epochs': 1,
      'lr': 0.1,
      'log_steps': 8,
      'opts': MODEL_OPTS.items()
  }

  global FLAGS
  FLAGS = args_parse.parse_common_options(**default_config)
  print('Start training loop...')
  train()
  dist.destroy_process_group()


if __name__ == '__main__':
  main()

[1] https://github.com/apple/axlearn/blob/main/axlearn/common/pipeline.py
[2] https://github.com/google/praxis/blob/main/praxis/layers/pipeline.py
[3] https://arxiv.org/abs/2412.14374
[4] https://arxiv.org/abs/2203.12533

cc: @tengyifei @wconstab @kwen2501

rpsilva-aws added the enhancement label Apr 22, 2025
rpsilva-aws changed the title from "MPMD+SPMD Pipeline Parallelism RFC" to "[RFC] MPMD+SPMD Pipeline Parallelism" Apr 22, 2025
miladm added the distributed label Apr 22, 2025
@miladm
Collaborator

miladm commented Apr 22, 2025

Thank you for sharing this RFC!

We are reviewing it. cc @lsy323 @tengyifei @haifeng-jin @bhavya01 @pgmoka @GleasonK

@rpsilva-aws
Collaborator Author

rpsilva-aws commented Apr 22, 2025

We have minimally prototyped this with XLA + SPMD, mimicking communication across heterogeneous SPMD worlds but without the native distributed pipelining APIs:

  • SPMD localization (similar functionality to Siyuan's local CR)
  • Inferring SPMD localization from all live tensors, abiding by the local SPMD mesh
  • Enabling distributed process groups for XLA + SPMD
  • Injecting P2P communication as custom ops with the XLA backend
  • Manually converting the P2P custom ops into the appropriate CC ops after the SPMD partitioner (custom PJRT pass)

The last two were Neuron-specific, but I'll adapt them and share them as part of the RFC, as Siyuan suggested.

@zpcore
Collaborator

zpcore commented Apr 23, 2025

Thanks for the RFC! This seems like quite a lot of work, and many details still need to be filled in, e.g., how to support profiling.

Got a few questions:

There is no need for CC ops to collect/reduce data that is communicated across pipeline stages.
How do you plan to sync input/output between different SPMD stages?

At a high level, I don't quite understand the benefit of the SPMD+MPMD solution in the RFC compared with the MPMD solution. I think leaving everything to the XLA compiler should be much easier and more efficient for overlapping communication overhead. Can you elaborate more on the motivation? Thanks!

@rpsilva-aws
Collaborator Author

rpsilva-aws commented Apr 23, 2025

Thanks @zpcore!

How do you plan to sync input/output between different SPMD stages?

In the RFC above, we started by focusing on local SPMD worlds that have the same number of participants and localized SPMD meshes. Communication occurs across SPMD local ranks: for example, with 2 stages of 8 devices each, we'd have P2P communication as [[0, 8], [1, 9], ..., [7, 15]], transferring sharded data across stages (with the receiving rank using device-data IR placeholders). We're also careful not to make decisions that would prevent a future extension to heterogeneous local SPMD worlds (different meshes or participant counts), which would require synchronization before or after stages.

On the top level question, capturing all different variants for the wider picture:

  • MPMD-only: We sacrifice the benefits of the GSPMD model strategy, as users must compromise to enable advanced scheduling strategies. This approach requires using MPMD for all parallelism strategies (TP, DP, SP, CP, FSDP), whereas we want users to still benefit from lightweight annotations, with XLA's partitioner placing the communication collectives within the SPMD worlds. Users would need to place all communication collectives manually and keep the one-process-per-device model.

  • SPMD-only: Advanced scheduling strategies become unavailable, as we're limited to the GPipe variant. Pipeline parallelism is fundamentally challenging with this model strategy. GPipe requires homogeneous stages and padding of the entire model computation, adding superfluous computation to each rank, sometimes non-negligibly (LM head, embedding, etc.). OpenXLA's potential work on reducing superfluous computation may help, but the requirement remains problematic. Recent presentations showed execution/communication overlap optimizations in OpenXLA for GPU with SPMD-only GPipe, achieving 1.67x and 1.14x improvements for Llama 3.1 70B and 405B over the prior GPipe baseline (sequential computation and communication). However, these are limited to this basic scheduling strategy, with its suboptimal memory/performance characteristics.

This does not limit opportunities to optimize communication overlap in the XLA compiler, as we can leverage insights from both model strategies. This will be addressed in follow-up OpenXLA RFCs, with a clear separation between SPMD and MPMD communication collectives, which can conceptually help with execution-sequence dependencies, communication overlap, and optimal computation scheduling. This work can also be adapted to other MPMD+SPMD strategies, even without pipeline parallelism (e.g., heterogeneous SPMD computation programs). We can work towards feature parity between the two approaches, as there is substantial overlap in implementation details, but that is out of scope.

The MPMD+SPMD strategy combines the advantages of both models. Recent work on Pathways and JaxPP ([4] and [3] above) has extended JAX with MPMD support, although those efforts subsequently focused on other aspects of the architecture stack. I tried to keep the motivation concise, but I can move some of this into the RFC if it provides a better comparative picture.

@miladm
Collaborator

miladm commented Apr 24, 2025

It would be great to discuss the runtime requirements of this work in a multi-host setting.

@tengyifei
Collaborator

@rpsilva-aws High level questions:

  • Are you proposing to have PyTorch/XLA capture the collectives used for Pipeline Parallelism into HLO ops? How would you run those ops?

  • What is "P2P" communication? Taken literally, that means point-to-point communication between two devices. But I'm not sure that's what you meant?
