Commit 2bf23f1

[PT-D] Add a Megatron-LM style tensor parallel example (#1008)
ghstack-source-id: 74b14ff
Pull Request resolved: #1006
1 parent 09ae1f9 commit 2bf23f1

3 files changed: +212 -0 lines changed

distributed/sharded_tensor/README.md (+20)
@@ -0,0 +1,20 @@
# PyTorch Sharding for Distributed Training: Tensor Parallel Example

This example demonstrates SPMD Megatron-LM style tensor parallelism using
PyTorch native sharding APIs, which include:

1. Sharding spec/plan and high-level APIs for module-level sharding.
2. Model-agnostic ops for `ShardedTensor`, such as `Linear` and `ReLU`.
3. An E2E demo of tensor parallelism for a given toy model (forward/backward + optimization).
4. An API to optimize parameters when they are `ShardedTensor`s.

More details about the design can be found at:
https://github.com/pytorch/pytorch/issues/72138

```
pip install -r requirements.txt
python main.py
```
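For orientation before the full files below, here is the core call sequence the README's feature list refers to, condensed from the `main.py` added in this commit. This snippet is illustrative only and is not a separate file in the commit; it assumes a 2-rank setup with one GPU per rank.

```
# Illustrative condensation of main.py below; not a file in this commit.
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# A spec describes how to chunk a tensor and where each shard lives.
colwise_spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
# A plan maps parameter names of a module to specs.
plan = ShardingPlan(plan={"net1.weight": colwise_spec})
# Inside each spawned process (after init_process_group), this rewrites the
# listed parameters of the module into ShardedTensors:
# shard_module(model, plan)
```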
distributed/sharded_tensor/requirements.txt (+1)

@@ -0,0 +1 @@
torch>=1.12.0
distributed/sharded_tensor/main.py (+191)

@@ -0,0 +1,191 @@
import argparse
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed._shard import shard_module
from torch.distributed._shard.sharded_optim import (
    ShardedOptimizer,
    named_params_with_sharded_tensor,
)
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

"""
This is the script to test Tensor Parallel (TP) on a toy model in a
Megatron-LM SPMD style. We show an E2E working flow covering forward,
backward, and optimization.

More context about the API designs can be found in the design issue:

https://github.com/pytorch/pytorch/issues/72138.

We use the example of two nn layers with an element-wise ReLU in between
to show an example of Megatron-LM, which was proposed in the paper:

https://arxiv.org/abs/1909.08053.

The basic idea is that we shard the first nn layer by column and the
second nn layer by row, so that we need neither an all-gather of the
result of the first layer nor an all-scatter of the input of the second
layer. We can speed up model training by avoiding communication between
the two layers.

To shard an nn module, we first need to create a sharding spec and plan,
and then we shard the module based on them. We use PyTorch native APIs
for all of these, and this example shows how to use them.

Additionally, we have built an optimizer for sharded modules. We show how
to use it in the example, too.
"""


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def _generate_sharding_spec(world_size):
    """
    We first need to create a sharding spec for our sharding work.

    For now, we only support sharding on one dimension, so we use
    ``ChunkShardingSpec`` to chunk the given sharding dim into
    equally sized splits. The behavior is similar to `torch.chunk`.

    We also need to create an output sharding spec for the second nn
    layer because we need to aggregate (reduce) its partial result.
    So we have a new sharding spec to represent how we store the
    aggregated result in a new sharded tensor.
    """
    placements = [f"rank:{idx}/cuda:{idx}" for idx in range(world_size)]
    # Shard the first nn module's weight by dim 0.
    # (nn.Linear transposes the weight internally, so dim 0 actually means column.)
    colwise_spec = ChunkShardingSpec(
        dim=0,
        placements=placements,
    )
    # Shard the second nn module's weight by dim 1.
    rowwise_spec = ChunkShardingSpec(
        dim=1,
        placements=placements,
    )
    # The result from the second nn.Linear layer needs aggregation by dim 0.
    output_spec = ChunkShardingSpec(
        dim=0,
        placements=placements,
    )
    return colwise_spec, rowwise_spec, output_spec


def _get_toy_module_optim(module, lr):
    """
    Create an optimizer for the sharded module by using ShardedOptimizer.
    """
    return ShardedOptimizer(
        dict(named_params_with_sharded_tensor(module)),
        torch.optim.SGD,  # SGD is for demo purposes only; one can use other optimizers.
        lr=lr,
    )


def _get_toy_module_sharding_plan(world_size):
    """
    The idea behind Megatron-LM is that:
    1. We shard the weight of the first nn layer by dim 0 (column-wise).
    2. We shard the weight of the second nn layer by dim 1 (row-wise).
    3. We aggregate the partial result of the second nn layer and
       store it as a sharded tensor by dim 0.
    4. Return the final result on the local shard.

    We then create sharding specs based on this idea and
    compose a sharding plan on top of the specs.
    """
    colwise_spec, rowwise_spec, output_spec = _generate_sharding_spec(world_size)
    return ShardingPlan(
        # Specify the sharding plan for the component of each module.
        plan={
            "net1.weight": colwise_spec,
            "net2.weight": rowwise_spec,
        },
        # Specify the sharding plan for the output of one particular module,
        # e.g., the output of the second nn layer in the example of Megatron-LM.
        output_plan={
            "net2": output_spec,
        },
        # Specify to get the tensor stored on the local shard if the output
        # is a sharded tensor.
        return_local_tensor=["net2"],
    )


def demo_tp(rank, args):
    """
    Main body of the demo of a basic version of tensor parallel by using
    PyTorch native sharded tensor APIs.
    """
    print(f"Running basic Megatron style TP example on rank {rank}.")
    setup(rank, args.world_size)
    # Create a sharding plan based on the given world_size.
    module_sharding_plan = _get_toy_module_sharding_plan(
        args.world_size
    )

    # Create the model and move it to GPU with id rank.
    model = ToyModel().cuda(rank)
    # Shard the module based on the created plan.
    shard_module(model, module_sharding_plan)
    # Create an optimizer for the sharded module.
    optimizer = _get_toy_module_optim(model, 0.002)

    # Perform a number of iterations of forward/backward
    # and optimization for the sharded module.
    for _ in range(args.iter_nums):
        inp = torch.rand(20, 10).cuda(rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()

    cleanup()


def run_demo(demo_fn, args):
    # Spawn one process per rank and run the demo function in each of them.
    mp.spawn(demo_fn,
             args=(args,),
             nprocs=args.world_size,
             join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    parser = argparse.ArgumentParser()
    # These are passed in via the command line.
    parser.add_argument("--world_size", type=int, default=n_gpus)
    parser.add_argument("--iter_nums", type=int, default=10)
    args = parser.parse_args()
    # The main entry point is called directly without using subprocess.
    if n_gpus < 2:
        print("Requires at least 2 GPUs to run.")
    else:
        run_demo(demo_tp, args)
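The docstring of `_generate_sharding_spec` notes that `ChunkShardingSpec` splits a tensor along one dimension, similarly to `torch.chunk`. The standalone sketch below (not part of the commit) uses plain `torch.chunk` on CPU to show the shard shapes the column-wise and row-wise specs would produce for the toy model's weights, assuming a world size of 2.

```
# Standalone single-process sketch, not part of the commit: shows the shard
# layout that ChunkShardingSpec produces for ToyModel's weights, using
# torch.chunk as the local analogue (assumes world_size = 2).
import torch

world_size = 2

# nn.Linear(10, 32) stores its weight as (out_features, in_features) = (32, 10).
net1_weight = torch.randn(32, 10)
# Column-wise parallelism: ChunkShardingSpec(dim=0) splits along dim 0.
print([s.shape for s in torch.chunk(net1_weight, world_size, dim=0)])
# -> [torch.Size([16, 10]), torch.Size([16, 10])]

# nn.Linear(32, 5) stores its weight as (5, 32).
net2_weight = torch.randn(5, 32)
# Row-wise parallelism: ChunkShardingSpec(dim=1) splits along dim 1.
print([s.shape for s in torch.chunk(net2_weight, world_size, dim=1)])
# -> [torch.Size([5, 16]), torch.Size([5, 16])]
```

Because the first layer's column-wise shards produce an output that is already split along the feature dimension, it lines up with the second layer's row-wise weight shards, which is what lets the example skip communication between the two layers, as described in the module docstring.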
