Skip to content

Commit cbb9f44

Browse files
committedApr 10, 2023
Save PrefixEncoder params only
1 parent 4478546 commit cbb9f44

File tree

5 files changed

+3858
-5
lines changed

5 files changed

+3858
-5
lines changed
 

‎ptuning/README.md

+12
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ bash train.sh
3939
```shell
4040
bash evaluate.sh
4141
```
42+
**[2023/04/10更新]** 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,在推理时需要同时载入原 ChatGLM-6B 模型以及 PrefixEncoder 的 Checkpoint,因此需要指定参数(已更新 `evaluate.sh`) :
43+
44+
```shell
45+
--model_name_or_path THUDM/chatglm-6b
46+
--ptuning_checkpoint $CHECKPOINT_PATH
47+
```
48+
49+
仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`
50+
51+
```shell
52+
--model_name_or_path $CHECKPOINT_PATH
53+
```
4254

4355
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
4456
`./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`

‎ptuning/evaluate.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \
99
--overwrite_cache \
1010
--prompt_column content \
1111
--response_column summary \
12-
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
12+
--model_name_or_path THUDM/chatglm-6b \
13+
--ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
1314
--output_dir ./output/$CHECKPOINT \
1415
--overwrite_output_dir \
1516
--max_source_length 64 \

‎ptuning/main.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import jieba
2929
from rouge_chinese import Rouge
3030
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31+
import torch
3132

3233
import transformers
3334
from transformers import (
@@ -110,13 +111,28 @@ def main():
110111

111112
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
112113

113-
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
114+
if model_args.ptuning_checkpoint is not None:
115+
# Evaluation
116+
# Loading extra state dict of prefix encoder
117+
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
118+
prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
119+
new_prefix_state_dict = {}
120+
for k, v in prefix_state_dict.items():
121+
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
122+
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
123+
else:
124+
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
114125

115126
if model_args.quantization_bit is not None:
116127
print(f"Quantized to {model_args.quantization_bit} bit")
117128
model = model.quantize(model_args.quantization_bit)
118-
model = model.half()
119-
model.transformer.prefix_encoder.float()
129+
if model_args.pre_seq_len is not None:
130+
# P-tuning v2
131+
model = model.half()
132+
model.transformer.prefix_encoder.float()
133+
else:
134+
# Finetune
135+
model = model.float()
120136

121137
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
122138

0 commit comments

Comments
 (0)