|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import torch |
| 4 | +import torch.distributed as dist |
| 5 | +import torch.multiprocessing as mp |
| 6 | +import torch.nn as nn |
| 7 | + |
| 8 | +from torch.distributed._shard import shard_module |
| 9 | +from torch.distributed._shard.sharded_optim import ( |
| 10 | + ShardedOptimizer, |
| 11 | + named_params_with_sharded_tensor, |
| 12 | +) |
| 13 | +from torch.distributed._shard.sharding_plan import ShardingPlan |
| 14 | +from torch.distributed._shard.sharding_spec import ChunkShardingSpec |
| 15 | + |
| 16 | +""" |
| 17 | +This is the script to test Tensor Parallel(TP) on a toy model in a |
| 18 | +Megetron-LM SPMD style. We show an E2E working flow from forward, |
| 19 | +backward and optimization. |
| 20 | +
|
| 21 | +More context about API designs can be found in the design: |
| 22 | +
|
| 23 | +https://github.com/pytorch/pytorch/issues/72138. |
| 24 | +
|
| 25 | +We use the example of two nn layers with an element-wise RELU in between |
| 26 | +to show an example of Megatron-LM, which was proposed in paper: |
| 27 | +
|
| 28 | +https://arxiv.org/abs/1909.08053. |
| 29 | +
|
| 30 | +The basic idea is that we shard the first nn by column and also shard |
| 31 | +the second nn by row so that we don't need the all gather of the result |
| 32 | +of first nn and all scatter of input of the second nn. We can speed up |
| 33 | +the model training by avoiding communications between two layers. |
| 34 | +
|
| 35 | +To shard a nn module, we need to create a sharding spec and plan first, |
| 36 | +and then we shard the module based on it. We will use PyTorch native APIs |
| 37 | +for all of them and this example shows how to use them. |
| 38 | +
|
| 39 | +Additionally, we have built an optimizer for sharded module. We show how |
| 40 | +to use it in the example, too. |
| 41 | +""" |
| 42 | + |
| 43 | + |
| 44 | +def setup(rank, world_size): |
| 45 | + os.environ['MASTER_ADDR'] = 'localhost' |
| 46 | + os.environ['MASTER_PORT'] = '12355' |
| 47 | + |
| 48 | + # initialize the process group |
| 49 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 50 | + torch.cuda.set_device(rank) |
| 51 | + |
| 52 | +def cleanup(): |
| 53 | + dist.destroy_process_group() |
| 54 | + |
| 55 | + |
| 56 | +class ToyModel(nn.Module): |
| 57 | + def __init__(self): |
| 58 | + super(ToyModel, self).__init__() |
| 59 | + self.net1 = nn.Linear(10, 32) |
| 60 | + self.relu = nn.ReLU() |
| 61 | + self.net2 = nn.Linear(32, 5) |
| 62 | + |
| 63 | + def forward(self, x): |
| 64 | + return self.net2(self.relu(self.net1(x))) |
| 65 | + |
| 66 | + |
| 67 | +def _generate_sharding_spec(world_size): |
| 68 | + """ |
| 69 | + We first need to create a sharding spec for our sharding work. |
| 70 | +
|
| 71 | + For now, we only support sharding on one dimension. So we use |
| 72 | + ``ChunkShardingSpec`` to chunk the size of the given sharding |
| 73 | + dim to equally split length. The behavior is similar to |
| 74 | + `torch.chunk`. |
| 75 | +
|
| 76 | + We also need to create the output sharding spec for the second nn |
| 77 | + because we need to aggregate(reduce) the partial result after the |
| 78 | + second nn layer. So we have a new sharding spec to represent that |
| 79 | + how we store the aggregation result in a new sharded tensor. |
| 80 | + """ |
| 81 | + placements = [f"rank:{idx}/cuda:{idx}" for idx in range(world_size)] |
| 82 | + # Shard the first nn module's weight by dim 0. |
| 83 | + # (nn.Linear transposes the weight internally so dim 0 actually means column) |
| 84 | + colwise_spec = ChunkShardingSpec( |
| 85 | + dim=0, |
| 86 | + placements=placements, |
| 87 | + ) |
| 88 | + # Shard the second nn module's weight by dim 1. |
| 89 | + rowwise_spec = ChunkShardingSpec( |
| 90 | + dim=1, |
| 91 | + placements=placements, |
| 92 | + ) |
| 93 | + # The result from the second nn.linear layer needs aggregation by dim 0. |
| 94 | + output_spec = ChunkShardingSpec( |
| 95 | + dim=0, |
| 96 | + placements=placements, |
| 97 | + ) |
| 98 | + return colwise_spec, rowwise_spec, output_spec |
| 99 | + |
| 100 | + |
| 101 | +def _get_toy_module_optim(module, lr): |
| 102 | + """ |
| 103 | + Creata a optimizer for sharded tensor by using ShardedOptimizer. |
| 104 | + """ |
| 105 | + return ShardedOptimizer( |
| 106 | + dict(named_params_with_sharded_tensor(module)), |
| 107 | + torch.optim.SGD, # SGD is only demo purpose, one can use other optims. |
| 108 | + lr=lr, |
| 109 | + ) |
| 110 | + |
| 111 | + |
| 112 | +def _get_toy_module_sharding_plan(world_size): |
| 113 | + """ |
| 114 | + The idea behind Megatron-LM is that: |
| 115 | + 1. We shard the weight of the first nn by dim 0 (col-wise) |
| 116 | + 2. We shard the weight of the second nn by dim 1 (row-wise) |
| 117 | + 3. We aggregate the partial result of the second nn layer and |
| 118 | + store it as a sharded tensor by dim 0. |
| 119 | + 4. Return the final result on the local shard. |
| 120 | +
|
| 121 | + We then need to create a sharding spec based on it and |
| 122 | + compose a sharding plan on the basis of the spec. |
| 123 | + """ |
| 124 | + colwise_spec, rowwise_spec, output_spec = _generate_sharding_spec(world_size) |
| 125 | + return ShardingPlan( |
| 126 | + # Specify the sharding plan for the component of each module. |
| 127 | + plan={ |
| 128 | + "net1.weight": colwise_spec, |
| 129 | + "net2.weight": rowwise_spec, |
| 130 | + }, |
| 131 | + # Specify the sharding plan for the output of one particular module. |
| 132 | + # e.g., the output of the second nn layer in the example of Megatron-LM. |
| 133 | + output_plan={ |
| 134 | + "net2": output_spec, |
| 135 | + }, |
| 136 | + # Specify to get the tensor stored on the local shard if the output |
| 137 | + # is a sharded tensor. |
| 138 | + return_local_tensor=["net2"], |
| 139 | + ) |
| 140 | + |
| 141 | + |
| 142 | +def demo_tp(rank, args): |
| 143 | + """ |
| 144 | + Main body of the demo of a basic version of tensor parallel by using |
| 145 | + PyTorch native sharded tensor APIs. |
| 146 | + """ |
| 147 | + print(f"Running basic Megatron style TP example on rank {rank}.") |
| 148 | + setup(rank, args.world_size) |
| 149 | + # create a sharding plan based on the given world_size. |
| 150 | + module_sharding_plan = _get_toy_module_sharding_plan( |
| 151 | + args.world_size |
| 152 | + ) |
| 153 | + |
| 154 | + # create model and move it to GPU with id rank |
| 155 | + model = ToyModel().cuda(rank) |
| 156 | + # Shard the module based on created plan. |
| 157 | + shard_module(model, module_sharding_plan) |
| 158 | + # Create a optimizer for the sharded module. |
| 159 | + optimizer = _get_toy_module_optim(model, 0.002) |
| 160 | + |
| 161 | + # Perform a num of iterations of forward/backward |
| 162 | + # and optimizations for the sharded module. |
| 163 | + for _ in range(args.iter_nums): |
| 164 | + inp = torch.rand(20, 10).cuda(rank) |
| 165 | + output = model(inp) |
| 166 | + output.sum().backward() |
| 167 | + optimizer.step() |
| 168 | + |
| 169 | + cleanup() |
| 170 | + |
| 171 | + |
| 172 | +def run_demo(demo_fn, args): |
| 173 | + mp.spawn(demo_fn, |
| 174 | + args=(args,), |
| 175 | + nprocs=args.world_size, |
| 176 | + join=True) |
| 177 | + |
| 178 | + |
| 179 | +if __name__ == "__main__": |
| 180 | + n_gpus = torch.cuda.device_count() |
| 181 | + parser = argparse.ArgumentParser() |
| 182 | + # This is passed in via cmd |
| 183 | + parser.add_argument("--world_size", type=int, default=n_gpus) |
| 184 | + parser.add_argument("--iter_nums", type=int, default=10) |
| 185 | + args = parser.parse_args() |
| 186 | + # The main entry point is called directly without using subprocess |
| 187 | + if n_gpus < 2: |
| 188 | + print("Requires at least 2 GPUs to run.") |
| 189 | + else: |
| 190 | + run_demo(demo_tp, args) |
| 191 | + |
0 commit comments