forked from feifeibear/long-context-attention
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Megatron-DeepSpeed.patch
327 lines (314 loc) · 11.7 KB
/
Megatron-DeepSpeed.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
From 01eb56347633f5f016ed9c9aa62b3e49d7cd37fa Mon Sep 17 00:00:00 2001
From: root <[email protected]>
Date: Fri, 19 Apr 2024 06:19:51 +0000
Subject: [PATCH 1/2] [cp] add hybrid context parallel
---
megatron/arguments.py | 2 +
megatron/core/parallel_state.py | 20 ++++
megatron/initialize.py | 3 +-
megatron/model/transformer.py | 14 ++-
start_gpt.sh | 176 ++++++++++++++++++++++++++++++++
5 files changed, 211 insertions(+), 4 deletions(-)
create mode 100755 start_gpt.sh
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 631d4b1..a91db90 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -951,6 +951,8 @@ def _add_training_args(parser):
help='Enable Megatron-LM\'s sequence parallel optimization.')
group.add_argument('--ds-sequence-parallel-size', type=int, default=1,
help='Enable DeepSpeed\'s sequence parallel. Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.')
+ group.add_argument('--ds-ring-sequence-parallel-size', type=int, default=1,
+ help='Ring sequenceparallel degree.')
group.add_argument('--force-ds-sequence-parallel', action='store_true',
help='use DeepSpeed sequence parallelism regardless of sequence parallel size.')
group.add_argument('--no-gradient-accumulation-fusion',
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index 819760e..2f2aad9 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -7,6 +7,12 @@ from typing import Optional
from .utils import GlobalMemoryBuffer
+try:
+ from yunchang import set_seq_parallel_pg
+ from yunchang.globals import PROCESS_GROUP as YUNCHANG_PROCESS_GROUP
+except ImportError:
+ set_seq_parallel_pg = None
+
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
@@ -70,6 +76,7 @@ def initialize_model_parallel(
pipeline_model_parallel_split_rank: Optional[int] = None,
use_fp8: bool = False,
use_distributed_optimizer: bool = False,
+ ring_parallel_size: int =1,
) -> None:
"""Initialize model data parallel groups.
@@ -213,6 +220,15 @@ def initialize_model_parallel(
if rank in ranks:
_SEQUENCE_PARALLEL_GROUP = group
+ ring_degree = ring_parallel_size
+ ulysse_degree = sequence_parallel_size // ring_parallel_size
+ assert sequence_parallel_size % ulysse_degree == 0, f"sequence_parallel_size {sequence_parallel_size} is not divisible by ulysse_degree {ulysse_degree}"
+ assert sequence_parallel_size == ring_degree * ulysse_degree, f"sequence_parallel_size {sequence_parallel_size} is not equal to ring_degree {ring_degree} * ulysse_degree {ulysse_degree}"
+ if set_seq_parallel_pg is not None:
+ set_seq_parallel_pg(ulysse_degree, ring_degree, rank, world_size)
+ else:
+ print("set_seq_parallel_pg is not available")
+
# Build the sequence data parallel groups.
global _SEQUENCE_DATA_PARALLEL_GROUP
assert _SEQUENCE_DATA_PARALLEL_GROUP is None, \
@@ -445,6 +461,10 @@ def get_model_parallel_world_size():
assert get_pipeline_model_parallel_world_size() == 1, "legacy get_model_parallel_world_size is only supported if PP is disabled"
return get_tensor_model_parallel_world_size()
+def get_ulysses_sequence_parallel_world_size():
+ """Return world size for the ulysses sequence parallel group."""
+ return torch.distributed.get_world_size(group=YUNCHANG_PROCESS_GROUP.ULYSSES_PG)
+
def get_sequence_parallel_world_size():
"""Return world size for the sequence parallel group."""
global _SEQUENCE_PARALLEL_WORLD_SIZE
diff --git a/megatron/initialize.py b/megatron/initialize.py
index 31f26c5..8b021be 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -244,7 +244,8 @@ def _initialize_distributed():
args.ds_sequence_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank,
- use_distributed_optimizer=args.use_distributed_optimizer)
+ use_distributed_optimizer=args.use_distributed_optimizer,
+ ring_parallel_size = args.ds_ring_sequence_parallel_size)
if args.rank == 0:
print(f'> initialized tensor model parallel with size '
f'{mpu.get_tensor_model_parallel_world_size()}')
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e75f13a..bcac2fb 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -50,6 +50,10 @@ except ImportError:
FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder")
flash_attn_builder = None
+try:
+ from yunchang import UlyssesAttention, LongContextAttention, set_seq_parallel_pg
+except ImportError:
+ UlyssesAttention = None
""" We use the following notation throughout this file:
h: hidden size
@@ -597,8 +601,11 @@ class ParallelAttention(MegatronModule):
or args.force_ds_sequence_parallel
if self.enable_ds_sequence_parallel:
assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
- assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
- self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+ # assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
+ # self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+ assert args.num_attention_heads % parallel_state.get_ulysses_sequence_parallel_world_size() == 0, \
+ f"Number of attention heads {args.num_attention_heads} must be divisible by the number of Ulysses sequence parallel partitions {parallel_state.get_ulysses_sequence_parallel_world_size()}"
+ self.dist_attn = LongContextAttention()
else:
if self.use_flash_attn:
self.core_attention_flash = local_attn
@@ -616,7 +623,6 @@ class ParallelAttention(MegatronModule):
input_is_parallel=True,
skip_bias_add=True)
-
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask,
rotary_pos_emb=None):
@@ -808,11 +814,13 @@ class ParallelAttention(MegatronModule):
query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
for x in (query_layer, key_layer, value_layer)]
+ # print(f"fjr-debug use fa query_layer {query_layer.shape}")
context_layer = self.dist_attn(query_layer, key_layer, value_layer)
if not self.use_flash_attn_triton:
context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
else:
+ # print(f"fjr-debug not use fa query_layer {query_layer.shape}")
context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask)
else:
if self.use_flash_attn:
diff --git a/start_gpt.sh b/start_gpt.sh
new file mode 100755
index 0000000..17efb7e
--- /dev/null
+++ b/start_gpt.sh
@@ -0,0 +1,176 @@
+#! /bin/bash
+
+####################################################
+#
+# usage:
+# bash start.sh <model_size> <master_addr> <node_num> <rank>
+#
+# supported model size: {7, 13, 175}
+#
+####################################################
+
+
+# env var
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+# nccl settings
+#export NCCL_DEBUG=INFO
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_IB_GID_INDEX=3
+export NCCL_IB_DISABLE=0
+export NCCL_NET_GDR_LEVEL=2
+export NCCL_IB_QPS_PER_CONNECTION=4
+export NCCL_IB_TC=160
+export NCCL_IB_TIMEOUT=22
+
+export GLOO_SOCKET_IFNAME=eth0
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# data settings
+BASE_DATA_PATH=/data/datasets/gpt-data/
+DATA_PATH=$BASE_DATA_PATH/my-gpt2_text_document
+VOCAB_FILE=$BASE_DATA_PATH/gpt2-vocab.json
+MERGE_FILE=$BASE_DATA_PATH/gpt2-merges.txt
+CHECKPOINT_PATH=./output/
+
+
+ZERO_STAGE=3
+
+# create DS config
+DS_CONFIG=ds_config.json
+DATA_TYPE=
+if [ ${ZERO_STAGE} -eq 1 ]; then
+ DATA_TYPE="
+ \"data_types\":{
+ \"grad_accum_dtype\":\"fp32\"
+ },
+ "
+fi
+
+
+# model settings
+SEQ_LEN=8192
+MAX_SEQ_LEN=8192
+MODEL_SIZE=${1:-7}
+if [ $MODEL_SIZE == "7" ]; then
+ NUM_LAYERS=32
+ HIDDEN_SIZE=4096
+ NUM_ATTN_HEADS=32
+ MICRO_BATCH_SIZE=1
+ TP=1
+ PP=1
+ CP=4
+ RCP=2
+ MICRO_BATCH_NUM=32
+elif [ $MODEL_SIZE == "13" ]; then
+ NUM_LAYERS=40
+ HIDDEN_SIZE=5120
+ NUM_ATTN_HEADS=40
+ MICRO_BATCH_SIZE=1
+ TP=1
+ PP=2
+ MICRO_BATCH_NUM=64
+elif [ $MODEL_SIZE == "175" ]; then
+ NUM_LAYERS=96
+ HIDDEN_SIZE=12288
+ NUM_ATTN_HEADS=96
+ MICRO_BATCH_SIZE=1
+ TP=8
+ PP=4
+ MICRO_BATCH_NUM=256
+else
+ echo "ERROR: Please supplement new model configuration to test!"
+ exit -1
+fi
+
+#fp8 settings
+ENABLE_FP8=false
+if [ $ENABLE_FP8 == "true" ]; then
+ FP8_OPTS="--transformer-impl transformer_engine --fp8-format hybrid "
+ DT="fp8"
+else
+ FP8_OPTS=""
+ DT="bf16"
+fi
+
+# node settings
+MASTER_ADDR=${2:-localhost}
+MASTER_PORT=6000
+NNODES=${3:-1}
+NODE_RANK=${4:-0}
+GPUS_PER_NODE=8
+WORLD_SIZE=$(( $GPUS_PER_NODE * $NNODES ))
+DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+DP=$(( $WORLD_SIZE / $TP / $PP / $CP))
+GLOBAL_BATCH_SIZE=$(( $DP * $MICRO_BATCH_SIZE * $MICRO_BATCH_NUM ))
+
+
+cat << EOT > $DS_CONFIG
+{
+ "train_batch_size" : $GLOBAL_BATCH_SIZE,
+ "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
+ "steps_per_print": 1,
+ "gradient_clipping": 1.0,
+ "zero_optimization": {
+ "stage": $ZERO_STAGE
+ },
+ "bf16": {
+ "enabled": true,
+ "accumulate_grads_via_hooks": true
+ },
+ "fp16": {"enabled": false},
+ "wall_clock_breakdown": false
+}
+EOT
+
+
+
+CMD="torchrun $DISTRIBUTED_ARGS \
+ pretrain_gpt.py \
+ --tensor-model-parallel-size $TP \
+ --pipeline-model-parallel-size $PP \
+ --ds-sequence-parallel-size $CP \
+ --ds-ring-sequence-parallel-size $RCP \
+ --num-layers $NUM_LAYERS \
+ --hidden-size $HIDDEN_SIZE \
+ --num-attention-heads $NUM_ATTN_HEADS \
+ --micro-batch-size $MICRO_BATCH_SIZE \
+ --global-batch-size $GLOBAL_BATCH_SIZE \
+ --seq-length $SEQ_LEN \
+ --max-position-embeddings $SEQ_LEN \
+ --train-iters 500 \
+ --lr-decay-iters 320000 \
+ --save $CHECKPOINT_PATH \
+ --data-path $DATA_PATH \
+ --vocab-file $VOCAB_FILE \
+ --merge-file $MERGE_FILE \
+ --split 949,50,1 \
+ --distributed-backend nccl \
+ --lr 0.00015 \
+ --lr-decay-style cosine \
+ --min-lr 1.0e-5 \
+ --weight-decay 1e-2 \
+ --clip-grad 1.0 \
+ --lr-warmup-fraction .01 \
+ --log-interval 1 \
+ --save-interval 10000 \
+ --eval-interval 10000 \
+ --exit-interval 10000 \
+ --eval-iters 1000 \
+ --use-flash-attn-v2 \
+ --recompute-activations \
+ --use-distributed-optimizer \
+ --bf16 \
+ $FP8_OPTS \
+ --deepspeed \
+ --deepspeed_config $DS_CONFIG \
+ --zero-stage=$ZERO_STAGE \
+ --no-pipeline-parallel \
+ "
+
+echo ${CMD} 2>&1 | tee megatron_gpt-${MODEL_SIZE}B_tp${TP}_pp${PP}_dp${DP}_mb${MICRO_BATCH_SIZE}_gb${GLOBAL_BATCH_SIZE}_${DT}.log
+eval ${CMD} 2>&1 | tee -a megatron_gpt-${MODEL_SIZE}B_tp${TP}_pp${PP}_dp${DP}_mb${MICRO_BATCH_SIZE}_gb${GLOBAL_BATCH_SIZE}_${DT}.log
--
2.34.1