
Commit f4688c5

test version
1 parent 21704bc commit f4688c5


10 files changed: +225 −52 lines changed


configs/7B_internlm2.py (+4)

@@ -38,6 +38,10 @@
     async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporarily files during asynchronous upload.
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
+    # If enable_save_ckpt=True, metadata will be automatically generated.
+    # If generate_meta_data.enable=True, metadata can be generated independently into generate_meta_data.path during initialization.
+    # If only the metadata needs to be generated, set generate_meta_data accordingly.
+    generate_meta_data=dict(enable=False, path='./')
 )

 TRAIN_FOLDER = None
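The new generate_meta_data switch is off by default. As a hedged illustration (not part of this commit), a config that only generates metadata during initialization, without saving full checkpoints, might look like the sketch below; only the enable and path keys come from the diff above, and the output folder name is made up.

# Hypothetical fragment of a training config; generate_meta_data keys follow the diff above.
ckpt = dict(
    enable_save_ckpt=False,   # no full checkpoint saving in this run
    generate_meta_data=dict(
        enable=True,          # write metadata independently during initialization
        path="./ckpt_meta/",  # illustrative output folder for the generated metadata
    ),
)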

internlm/checkpoint/checkpoint_manager.py (+6)

@@ -229,6 +229,7 @@ def __init__(
         model_config=None,
         model_config_file=None,
         feishu_address=None,
+        meta_data=None,
     ) -> None:
         """
         CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
@@ -247,6 +248,7 @@ def __init__(
         self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
         self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
         self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
+        self.meta_data = meta_data
         if self.save_ckpt_folder:
             self.snapshot_ckpt_folder = get_config_value(
                 ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
@@ -629,6 +631,10 @@ def save_checkpoint(
             save_optimizer_checkpoint(optim=optimizer, state_path=folder)
             timer("save-optimizer").stop()

+            if gpc.get_global_rank() == 0:
+                assert self.meta_data is not None
+                llm_save(os.path.join(folder, "metadata.pt"), saved_obj=self.meta_data)
+
         if (
             hasattr(train_state, "data_state_dict")
             and gpc.get_local_rank(ParallelMode.TENSOR) == 0
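With this change, save_checkpoint writes the metadata next to the checkpoint from global rank 0. A minimal sketch for inspecting that file offline, assuming llm_save produced a torch.load-compatible metadata.pt (the folder path is illustrative):

import os
import torch

ckpt_folder = "/path/to/save_ckpt_folder/100"  # illustrative checkpoint step folder
meta = torch.load(os.path.join(ckpt_folder, "metadata.pt"), map_location="cpu")
print(type(meta))  # object produced by generate_meta_data(optimizer) in trainer_builder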

internlm/core/trainer_builder.py (+9 −2)

@@ -20,6 +20,7 @@
 from internlm.model.metrics import AccPerplex
 from internlm.monitor.monitor import send_alert_message
 from internlm.train.pipeline import (
+    generate_meta_data,
     get_scheduler_hooks,
     initialize_llm_profile,
     initialize_optimizer,
@@ -124,8 +125,13 @@ def __init__(
         # initialize optimizer
         optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)

+        # generate ckpt metaData
+        meta_data = generate_meta_data(optimizer)
+
         # initialize checkpoint manager and try resume training
-        self.ckpt_manager = self._initialize_checkpoint_manager(model, optimizer, lr_scheduler, train_dl, config_lines)
+        self.ckpt_manager = self._initialize_checkpoint_manager(
+            model, optimizer, lr_scheduler, train_dl, config_lines, meta_data
+        )
         self.ckpt_manager.try_resume_training(train_state, self.current_time)

         # initialize customed llm writer
@@ -178,7 +184,7 @@ def _initialize_criterion(self) -> FlashGPTLMLoss:
         )

     def _initialize_checkpoint_manager(
-        self, model, optimizer, lr_scheduler, train_dl, config_lines
+        self, model, optimizer, lr_scheduler, train_dl, config_lines, meta_data
     ) -> CheckpointManager:
         return CheckpointManager(
             ckpt_config=gpc.config.ckpt,
@@ -189,6 +195,7 @@ def _initialize_checkpoint_manager(
             model_config=gpc.config.model,
             model_config_file="".join(config_lines),
             feishu_address=gpc.config.monitor.alert.feishu_alert_address,
+            meta_data=meta_data,
         )

     def _initialize_writer(self, train_state, config_lines) -> Writer:

internlm/initialize/launch.py (+3)

@@ -214,6 +214,9 @@ def args_sanity_check():
     if "enable_save_ckpt" not in ckpt:
         ckpt._add_item("enable_save_ckpt", True)

+    if "generate_meta_data" not in ckpt:
+        ckpt._add_item("generate_meta_data", dict(enable=False, path=None))
+
     # Saving checkpoint args.
     if ckpt.enable_save_ckpt:
         assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"

internlm/model/modules/embedding.py (+10 −2)

@@ -47,27 +47,35 @@ def __init__(
         self.vocab_parallel = vocab_parallel

         parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
+        rank = gpc.get_local_rank(ParallelMode.WEIGHT) if is_using_isp() else gpc.get_local_rank(ParallelMode.TENSOR)

         if vocab_parallel:
             assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}"

             self.num_embeddings_per_partition = num_embeddings // parallel_size
             self.embed_dim_per_partition = embedding_dim
-            self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
+            self.vocab_start_index = rank * self.num_embeddings_per_partition
             self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+            self.offset = [self.vocab_start_index, 0]
+            self.tp_dim = 0
         else:
             assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}"

             self.num_embeddings_per_partition = num_embeddings
             self.embed_dim_per_partition = embedding_dim // parallel_size
             self.vocab_start_index = 0
             self.vocab_end_index = self.num_embeddings_per_partition
+            self.offset = [0, self.embed_dim_per_partition * rank]
+            self.tp_dim = 1

         self.weight = nn.Parameter(
             torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
         )
-
+        self.complete_size = [num_embeddings, embedding_dim]
         setattr(self.weight, "is_embedding_param", True)
+        setattr(self.weight, "offset", self.offset)
+        setattr(self.weight, "complete_size", [num_embeddings, embedding_dim])
+        setattr(self.weight, "tp_dim", self.tp_dim)

     def forward(self, input_: Tensor) -> Tensor:
         if self.vocab_parallel and not is_using_isp():
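The offset, complete_size, and tp_dim attributes attached to self.weight describe where the local shard sits inside the full (num_embeddings, embedding_dim) matrix: tp_dim is the split dimension (0 for vocab-parallel, 1 for embedding-dim-parallel) and offset is the shard's starting coordinate. A small illustrative helper (not part of this commit) that turns them into a global slice:

import torch

def global_slice_of_shard(weight: torch.Tensor):
    # Map a sharded weight back to its region of the complete tensor,
    # using the attributes set in Embedding1D.__init__ above.
    row_off, col_off = weight.offset             # e.g. [vocab_start_index, 0] or [0, embed_dim_per_partition * rank]
    rows, cols = weight.shape                    # local shard shape
    full_rows, full_cols = weight.complete_size  # [num_embeddings, embedding_dim]
    assert row_off + rows <= full_rows and col_off + cols <= full_cols
    split_dim = weight.tp_dim                    # which dimension was partitioned
    return split_dim, (slice(row_off, row_off + rows), slice(col_off, col_off + cols))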

internlm/model/modules/linear.py (+10)

@@ -597,6 +597,7 @@ def __init__(

         world_size = gpc.get_world_size(parallel_mode)
         rank = gpc.get_local_rank(parallel_mode)
+        self.offset = None

         if split_mode != "none":
             split_features = out_features if split_mode == "column" else in_features
@@ -611,11 +612,20 @@ def __init__(

         if split_mode == "column":
             super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype)
+            self.offset = [rank * local_multiple * multiple_of, 0]
+            self.tp_dim = 0
         elif split_mode == "row":
             super().__init__(local_multiple * multiple_of, out_features, bias=bias, device=device, dtype=dtype)
+            self.offset = [0, rank * local_multiple * multiple_of]
+            self.tp_dim = 1
         else:
             super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

+        self.complete_size = [out_features, in_features]
+        setattr(self.weight, "offset", self.offset)
+        setattr(self.weight, "complete_size", [out_features, in_features])
+        setattr(self.weight, "tp_dim", self.tp_dim)
+
     def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor:  # pylint: disable=W0622
         _class_name = self.__class__.__name__
         assert self._communicator is not None, f"{_class_name} should register with a communicator first."

internlm/model/modules/mha.py (+4 −20)

@@ -65,22 +65,6 @@ def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict:  # pylint: disable=W0613
     )


-def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict:  # pylint: disable=W0613
-    wq_name, wk_name, wv_name, fused_name = (
-        f"{prefix}wq.weight",
-        f"{prefix}wk.weight",
-        f"{prefix}wv.weight",
-        f"{prefix}wqkv.weight",
-    )
-
-    if module.enable_qkv_fusion:
-        state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight(
-            state_dict.pop(fused_name), *args, **kwargs
-        )
-
-    return state_dict
-
-
 class MHA(nn.Module):
     """
     Multi-head self-attention and cross-attention.
@@ -462,15 +446,15 @@ def __init__(
         if enable_qkv_fusion:
             assert bias is False, "Fuesd wqkv only support bias is False."
             self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
-            self._register_load_state_dict_pre_hook(
-                partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
-            )
-            self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim))
         else:
             self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
             self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
             self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)

+        self._register_load_state_dict_pre_hook(
+            partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
+        )
+
         self.inner_attn = SelfAttention(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
         )

internlm/solver/optimizer/hybrid_zero_optim.py (+20 −1)

@@ -149,7 +149,7 @@ def __init__(
             assert self._param_bcast_sync_handler is not None

         self._isp_communicator = isp_communicator
-
+        self.meta_for_zero = None
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
@@ -165,6 +165,9 @@ def __init__(
             zero_mode = param_group["optimizer_mode"]
             self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
             self._zero_world_size.append(gpc.get_world_size(zero_mode))
+
+            if self.meta_for_zero is None:
+                self.meta_for_zero = [{} for _ in range(gpc.get_world_size(zero_mode))]
             # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name
             self._broadcast_parallel_mode.append(zero_mode)

@@ -278,6 +281,22 @@ def _partition_param_list(self, group_id, param_group):
             else:
                 rank_to_go = numel_per_rank.index(min(numel_per_rank))
             params_per_rank[rank_to_go].append(param)
+
+            if group_id not in self.meta_for_zero[rank_to_go]:
+                self.meta_for_zero[rank_to_go][group_id] = {}
+
+            from internlm.train.pipeline import map_fqn_local_to_global
+
+            global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn
+            self.meta_for_zero[rank_to_go][group_id][global_fqn] = {
+                "tp_dim": getattr(param, "tp_dim", -1),
+                "pp": gpc.get_local_rank(ParallelMode.PIPELINE),
+                "zero1": rank_to_go,
+                "fqn": param.fqn,
+                "shape": param.shape,
+                "group_id": group_id,
+            }
+
             self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
             numel_per_rank[rank_to_go] += param.numel()
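After partitioning, meta_for_zero is a list indexed by zero1 rank, where each element maps group_id to a dict keyed by global FQN with the fields shown above. A hedged sketch (the lookup helper is illustrative, not from this commit) of querying it to find which zero1 rank owns a parameter:

# meta_for_zero[zero1_rank][group_id][global_fqn] ->
#     {"tp_dim", "pp", "zero1", "fqn", "shape", "group_id"}
def find_param_owner(meta_for_zero, target_global_fqn):
    # Scan every zero1 rank and parameter group for the requested global FQN.
    for zero1_rank, groups in enumerate(meta_for_zero):
        for group_id, params in groups.items():
            if target_global_fqn in params:
                entry = params[target_global_fqn]
                return zero1_rank, group_id, entry["shape"], entry["tp_dim"]
    return None  # not tracked on this pipeline rank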
