-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathrun_embedding.py
286 lines (251 loc) · 11.5 KB
/
run_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import inspect
import os
import sys
import paddle
from utils.argument import EmbeddingArgument
from paddlenlp.data import DataCollatorForEmbedding
from paddlenlp.datasets import EmbeddingIterableDataset, load_dataset
from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed
from paddlenlp.trainer.trainer_callback import TrainerState
from paddlenlp.transformers import (
AutoConfig,
AutoTokenizer,
Qwen2Config,
Qwen2SentenceEmbedding,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig
from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template
from paddlenlp.utils.log import logger
# Fine-tune environment variable to support sharding stage1 overlap optimization.
# It is set at import time, i.e. before main() builds the model or the trainer.
# NOTE(review): the key is spelled "CASUAL" (sic); presumably it has to match the
# name read by the consuming library -- confirm before "correcting" the spelling.
os.environ.update({"USE_CASUAL_MASK": "False"})
def main():
    """Train a Qwen2 sentence-embedding model with ``EmbeddingTrainer``.

    The function runs the following steps in order:

    1. Parse ``ModelConfig``, ``DataConfig``, ``SFTConfig`` and
       ``EmbeddingArgument`` either from a JSON file (when the first CLI
       argument ends with ``.json``) or from the command line.
    2. Select the device, seed the RNGs and reject pipeline parallelism.
    3. Look for an existing checkpoint in ``training_args.output_dir``.
    4. Build the model config (only ``Qwen2Config`` is accepted) and create
       the ``Qwen2SentenceEmbedding`` model, either from pretrained weights
       or from the config alone.
    5. Load the tokenizer and the train/dev datasets. Evaluation is
       force-disabled for embedding training.
    6. Wrap the datasets in ``EmbeddingIterableDataset``, build the collator
       and trainer, then train and save the model, metrics and state.

    Raises:
        NotImplementedError: If ``pipeline_parallel_degree`` is greater than 1.
        ValueError: If ``fp16_opt_level`` is ``"O2"`` but neither ``fp16`` nor
            ``bf16`` is set, or if ``dataset_name_or_path`` is ``None``.
        AssertionError: If the loaded model config is not a ``Qwen2Config``.
    """
    parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument))
    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines()
    else:
        model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses()

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    # Setup GPU & distributed training
    paddle.set_device(training_args.device)
    set_seed(seed=training_args.seed)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    if training_args.pipeline_parallel_degree > 1:
        raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Load model
    # Reduced-precision weights are only requested at opt level "O2"; every
    # other level keeps the weights in float32.
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        elif training_args.bf16:
            dtype = "bfloat16"
        else:
            raise ValueError("Please specific dtype: --fp16 or --bf16")
    else:
        dtype = "float32"

    model_config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        dtype=dtype,
        from_aistudio=model_args.from_aistudio,
    )
    assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported"

    LlmMetaConfig.set_llm_config(model_config, training_args)
    model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

    # Config for model using dropout, such as GPT.
    if hasattr(model_config, "hidden_dropout_prob"):
        model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
    if hasattr(model_config, "attention_probs_dropout_prob"):
        model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
    if hasattr(model_config, "ignore_index"):
        model_config.ignore_index = -100

    # Fusion switches are only overridden when the user set them explicitly.
    if model_args.fuse_attention_qkv is not None:
        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
    if model_args.fuse_attention_ffn is not None:
        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn

    model_config.seq_length = data_args.max_length
    model_config.embedding_negatives_cross_device = embedding_args.embedding_negatives_cross_device
    logger.info(f"Final model config: {model_config}")

    model_class = Qwen2SentenceEmbedding

    if model_args.continue_training and not training_args.autotuner_benchmark:
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=model_config,
            from_aistudio=model_args.from_aistudio,
        )
    else:
        model = model_class.from_config(model_config, dtype=dtype)

    # `flash_mask` forces both of its prerequisites on instead of failing.
    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
        logger.warning("`flash_mask` must use with zero padding and flash attention.")
        data_args.zero_padding = True
        model.config.use_flash_attention = True

    # Load tokenizer & dataset
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, from_aistudio=model_args.from_aistudio)

    # init chat_template for tokenizer
    init_chat_template(tokenizer, model_args.model_name_or_path, data_args.chat_template)

    # if using chat_template, data_args.eval_with_do_generation must be false
    if tokenizer.chat_template is not None:
        data_args.eval_with_do_generation = False

    # Evaluation is switched off here, so every `do_eval` check below sees
    # False and `dev_ds` ends up as None.
    if training_args.do_eval:
        logger.warning("Warning: 'do_eval' is set to True, but will be set to False for Embedding training currently.")
        training_args.do_eval = False
        training_args.evaluation_strategy = "no"

    if data_args.dataset_name_or_path is None:
        raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})")
    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists(
        os.path.join(data_args.dataset_name_or_path, "dev.json")
    ):
        # Layout 1: <dataset_dir>/train.json and/or <dataset_dir>/dev.json
        if training_args.do_train:
            train_ds = load_dataset(
                "json",
                data_files=os.path.join(data_args.dataset_name_or_path, "train.json"),
                lazy=data_args.lazy,
            )[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(
                "json",
                data_files=os.path.join(data_args.dataset_name_or_path, "dev.json"),
                lazy=data_args.lazy,
            )[0]
        else:
            dev_ds = None
    elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train")) or os.path.exists(
        os.path.join(data_args.dataset_name_or_path, "dev")
    ):
        # Layout 2: <dataset_dir>/train/*.json and/or <dataset_dir>/dev/*.json
        import glob

        if training_args.do_train:
            train_ds = load_dataset(
                "json",
                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "train", "*.json")),
                lazy=data_args.lazy,
            )[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(
                "json",
                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
                lazy=data_args.lazy,
            )[0]
        else:
            dev_ds = None
    else:
        # Fallback: hand the name or path straight to `load_dataset`.
        if training_args.do_train:
            train_ds = load_dataset(data_args.dataset_name_or_path, splits=["train"])[0]
        else:
            train_ds = None
        if training_args.do_eval:
            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
        else:
            dev_ds = None

    # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
    # The `train_ds is not None` guard matters: with `do_train` disabled no
    # training dataset is loaded, and calling `.skip` on None would raise
    # AttributeError before the script could do anything else.
    if training_args.resume_from_checkpoint is not None and data_args.lazy and train_ds is not None:
        logger.info(
            f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True."
        )
        training_args.ignore_data_skip = True
        state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json"))
        if state.trial_params is not None and "zero_padding_global_step" in state.trial_params:
            consumed_samples = state.trial_params["zero_padding_global_step"]
        else:
            # No recorded counter: derive the number of consumed samples from
            # the step count and the effective global batch size.
            consumed_samples = (
                state.global_step
                * training_args.per_device_train_batch_size
                * training_args.gradient_accumulation_steps
                * training_args.dataset_world_size
            )
        logger.info(
            f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'."
        )
        train_ds = train_ds.skip(consumed_samples)

    if train_ds is not None:
        train_ds = EmbeddingIterableDataset(
            train_ds,
            tokenizer,
            max_query_len=embedding_args.max_query_len,
            max_passage_len=embedding_args.max_passage_len,
            group_size=embedding_args.group_size,
            query_template=embedding_args.query_template,
            passage_template=embedding_args.passage_template,
        )
    if dev_ds is not None:
        dev_ds = EmbeddingIterableDataset(
            dev_ds,
            tokenizer,
            max_query_len=embedding_args.max_query_len,
            max_passage_len=embedding_args.max_passage_len,
            group_size=embedding_args.group_size,
            query_template=embedding_args.query_template,
            passage_template=embedding_args.passage_template,
        )

    # Create trainer
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        padding = True

    data_collator_fn = DataCollatorForEmbedding(
        tokenizer=tokenizer,
        max_query_len=embedding_args.max_query_len,
        padding=padding,
        max_passage_len=embedding_args.max_passage_len,
        return_tensors="np",
        return_attention_mask=not model_args.flash_mask,
        pad_to_multiple_of=data_args.pad_to_multiple_of,
    )
    trainer = EmbeddingTrainer(
        model=model,
        model_args=embedding_args,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=data_collator_fn,
    )
    # Only parameters that still receive gradients are handed to the optimizer.
    trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
    trainer.set_optimizer_grouped_parameters(trainable_parameters)

    # Train
    if training_args.do_train:
        # An explicit `resume_from_checkpoint` wins over an auto-detected one.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Evaluation dev set
    # NOTE(review): `do_eval` was forced to False earlier in this function, so
    # this branch looks unreachable unless something re-enables it -- confirm.
    if training_args.do_eval:
        logger.info("*** Evaluate result after train ***")
        eval_result = trainer.evaluate(dev_ds)
        trainer.log_metrics("eval", eval_result)
if __name__ == "__main__":
    # Run the training entry point only when this file is executed as a
    # script, not when it is imported as a module.
    main()