Skip to content

Commit 1dcb081

Browse files
committed
s
1 parent 967ed92 commit 1dcb081

File tree

2 files changed

+686
-0
lines changed

2 files changed

+686
-0
lines changed

docs/ai/chat-glm3-funetune.md

+335
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
# ChatGLM3-6B 微调
2+
3+
本目录提供 ChatGLM3-6B 模型的微调示例,包括全量微调和 P-Tuning v2。格式上,提供多轮对话微调样例和输入输出格式微调样例。
4+
5+
如果将模型下载到了本地,本文和代码中的 `THUDM/chatglm3-6b` 字段均应替换为相应地址以从本地加载模型。
6+
7+
运行示例需要 `python>=3.10`,除基础的 `torch` 依赖外,示例代码运行还需要依赖。
8+
9+
**我们提供了 [示例notebook](lora_finetune.ipynb) 用于演示如何使用我们的微调代码。**
10+
11+
```bash
12+
pip install -r requirements.txt
13+
```
14+
15+
## 测试硬件标准
16+
17+
我们仅提供了单机多卡/多机多卡的运行示例,因此您需要至少一台具有多个 GPU 的机器。本仓库中的**默认配置文件**中,我们记录了显存的占用情况:
18+
19+
+ SFT 全量微调: 4张显卡平均分配,每张显卡占用 `48346MiB` 显存。
20+
+ P-TuningV2 微调: 1张显卡,占用 `18426MiB` 显存。
21+
+ LORA 微调: 1张显卡,占用 `14082MiB` 显存。
22+
23+
> 请注意,该结果仅供参考,对于不同的参数,显存占用可能会有所不同。请结合你的硬件情况进行调整。
24+
25+
> 请注意,我们仅仅使用英伟达 Hopper(代表显卡:H100) 和 Ampère(代表显卡:A100) 架构和系列显卡做过测试。如果您使用其他架构的显卡,可能会出现
26+
> 1. 未知的训练问题 / 显存占用与上述有误差。
27+
> 2. 架构过低而不支持某些特性。
28+
> 3. 推理效果问题。
29+
> 以上三种情况为社区曾经遇到过的问题,虽然概率极地,如果您遇到了以上问题,可以尝试在社区中解决。
30+
31+
## 多轮对话格式
32+
33+
多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`
34+
35+
对于数据文件,样例采用如下格式
36+
37+
如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。
38+
39+
```json
40+
[
41+
{
42+
"conversations": [
43+
{
44+
"role": "system",
45+
"content": "<system prompt text>"
46+
},
47+
{
48+
"role": "user",
49+
"content": "<user prompt text>"
50+
},
51+
{
52+
"role": "assistant",
53+
"content": "<assistant response text>"
54+
},
55+
// ... Muti Turn
56+
{
57+
"role": "user",
58+
"content": "<user prompt text>"
59+
},
60+
{
61+
"role": "assistant",
62+
"content": "<assistant response text>"
63+
}
64+
]
65+
}
66+
// ...
67+
]
68+
```
69+
70+
**请注意,这种方法在微调的step较多的情况下会影响到模型的工具调用功能**
71+
72+
如果您希望微调模型的对话和工具能力,您应该按照以下格式整理数据。
73+
74+
```json
75+
[
76+
{
77+
"tools": [
78+
// available tools, format is not restricted
79+
],
80+
"conversations": [
81+
{
82+
"role": "system",
83+
"content": "<system prompt text>"
84+
},
85+
{
86+
"role": "user",
87+
"content": "<user prompt text>"
88+
},
89+
{
90+
"role": "assistant",
91+
"content": "<assistant thought to text>"
92+
},
93+
{
94+
"role": "tool",
95+
"name": "<name of the tool to be called",
96+
"parameters": {
97+
"<parameter_name>": "<parameter_value>"
98+
},
99+
"observation": "<observation>"
100+
// don't have to be string
101+
},
102+
{
103+
"role": "assistant",
104+
"content": "<assistant response to observation>"
105+
},
106+
// ... Muti Turn
107+
{
108+
"role": "user",
109+
"content": "<user prompt text>"
110+
},
111+
{
112+
"role": "assistant",
113+
"content": "<assistant response text>"
114+
}
115+
]
116+
}
117+
// ...
118+
]
119+
```
120+
121+
- 关于工具描述的 system prompt 无需手动插入,预处理时会将 `tools` 字段使用 `json.dumps(..., ensure_ascii=False)`
122+
格式化后插入为首条 system prompt。
123+
124+
- 每种角色可以附带一个 `bool` 类型的 `loss` 字段,表示该字段所预测的内容是否参与 `loss`
125+
计算。若没有该字段,样例实现中默认对 `system`, `user` 不计算 `loss`,其余角色则计算 `loss`
126+
127+
- `tool` 并不是 ChatGLM3 中的原生角色,这里的 `tool` 在预处理阶段将被自动转化为一个具有工具调用 `metadata``assistant`
128+
角色(默认计算 `loss`)和一个表示工具返回值的 `observation` 角色(不计算 `loss`)。
129+
130+
- 目前暂未实现 `Code interpreter` 的微调任务。
131+
132+
- `system` 角色为可选角色,但若存在 `system` 角色,其必须出现在 `user`
133+
角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `system` 角色。
134+
135+
## 数据集格式示例
136+
137+
这里以 AdvertiseGen 数据集为例,
138+
您可以从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
139+
或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载 AdvertiseGen 数据集。
140+
将解压后的 AdvertiseGen 目录放到 `data` 目录下并自行转换为如下格式数据集。
141+
142+
> 请注意,现在的微调代码中加入了验证集,因此,对于一组完整的微调数据集,必须包含训练数据集和验证数据集,测试数据集可以不填写。或者直接用验证数据集代替。
143+
144+
```
145+
{"conversations": [{"role": "user", "content": "类型#裙*裙长#半身裙"}, {"role": "assistant", "content": "这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。"}]}
146+
```
147+
148+
## 配置文件
149+
150+
微调配置文件位于 `config` 目录下,包括以下文件:
151+
152+
1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed 配置文件。
153+
2. `lora.yaml / ptuning.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下:
154+
+ data_config 部分
155+
+ train_file: 训练数据集的文件路径。
156+
+ val_file: 验证数据集的文件路径。
157+
+ test_file: 测试数据集的文件路径。
158+
+ num_proc: 在加载数据时使用的进程数量。
159+
+ max_input_length: 输入序列的最大长度。
160+
+ max_output_length: 输出序列的最大长度。
161+
+ training_args 部分
162+
+ output_dir: 用于保存模型和其他输出的目录。
163+
+ max_steps: 训练的最大步数。
164+
+ per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。
165+
+ dataloader_num_workers: 加载数据时使用的工作线程数量。
166+
+ remove_unused_columns: 是否移除数据中未使用的列。
167+
+ save_strategy: 模型保存策略(例如,每隔多少步保存一次)。
168+
+ save_steps: 每隔多少步保存一次模型。
169+
+ log_level: 日志级别(如 info)。
170+
+ logging_strategy: 日志记录策略。
171+
+ logging_steps: 每隔多少步记录一次日志。
172+
+ per_device_eval_batch_size: 每个设备的评估批次大小。
173+
+ evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。
174+
+ eval_steps: 每隔多少步进行一次评估。
175+
+ predict_with_generate: 是否使用生成模式进行预测。
176+
+ generation_config 部分
177+
+ max_new_tokens: 生成的最大新 token 数量。
178+
+ peft_config 部分
179+
+ peft_type: 使用的参数有效调整类型(如 LORA)。
180+
+ task_type: 任务类型,这里是因果语言模型(CAUSAL_LM)。
181+
+ Lora 参数:
182+
+ r: LoRA 的秩。
183+
+ lora_alpha: LoRA 的缩放因子。
184+
+ lora_dropout: 在 LoRA 层使用的 dropout 概率
185+
+ P-TuningV2 参数:
186+
+ num_virtual_tokens: 虚拟 token 的数量。
187+
188+
## 开始微调
189+
190+
通过以下代码执行 **单机多卡/多机多卡** 运行。
191+
192+
```angular2html
193+
cd finetune_demo
194+
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml --deepspeed ds_zero_2.json
195+
```
196+
197+
通过以下代码执行 **单机单卡** 运行。
198+
199+
```angular2html
200+
cd finetune_demo
201+
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml
202+
```
203+
204+
单机及多机的第四参数(no)为是否断点继训,可输入类型有三种
205+
1:no 直接重新训练
206+
2:yes 自动从最后一个保存的 Checkpoint开始训练
207+
3:XX 断点号数字 例 600 则从序号600 Checkpoint开始训练
208+
209+
## 使用微调后的模型
210+
211+
### 在 inference_hf.py 中验证微调后的模型
212+
213+
您可以在 `finetune_demo/inference_hf.py` 中使用我们的微调后的模型,仅需要一行代码就能简单的进行测试。
214+
215+
```angular2html
216+
python inference_hf.py your_finetune_path --prompt your prompt
217+
```
218+
219+
这样,得到的回答就微调后的回答了。
220+
221+
### 在本仓库的其他 demo 或者外部仓库使用微调后的模型
222+
223+
您可以在任何一个 demo 内使用我们的 `lora` 和 全参微调的模型。这需要你自己按照以下教程进行修改代码。
224+
225+
1. 使用`finetune_demo/inference_hf.py`中读入模型的方式替换 demo 中读入模型的方式。
226+
227+
> 请注意,对于 LORA 和 P-TuningV2 我们没有合并训练后的模型,而是在`adapter_config.json`
228+
> 中记录了微调型的路径,如果你的原始模型位置发生更改,则你应该修改`adapter_config.json``base_model_name_or_path`的路径。
229+
230+
```python
231+
def load_model_and_tokenizer(
232+
model_dir: Union[str, Path], trust_remote_code: bool = True
233+
) -> tuple[ModelType, TokenizerType]:
234+
model_dir = _resolve_path(model_dir)
235+
if (model_dir / 'adapter_config.json').exists():
236+
model = AutoPeftModelForCausalLM.from_pretrained(
237+
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
238+
)
239+
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
240+
else:
241+
model = AutoModelForCausalLM.from_pretrained(
242+
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
243+
)
244+
tokenizer_dir = model_dir
245+
tokenizer = AutoTokenizer.from_pretrained(
246+
tokenizer_dir, trust_remote_code=trust_remote_code
247+
)
248+
return model, tokenizer
249+
```
250+
251+
2. 读取微调的模型,请注意,你应该使用微调模型的位置,例如,若你的模型位置为`/path/to/finetune_adapter_model`
252+
,原始模型地址为`path/to/base_model`,则你应该使用`/path/to/finetune_adapter_model`作为`model_dir`
253+
3. 完成上述操作后,就能正常使用微调的模型了,其他的调用方式没有变化。
254+
255+
### 提示
256+
257+
1. 微调代码在开始训练前,会先打印首条训练数据的预处理信息(默认已经注释,可以解除注释),显示为
258+
259+
```log
260+
Sanity
261+
Check >> >> >> >> >> >> >
262+
'[gMASK]': 64790 -> -100
263+
'sop': 64792 -> -100
264+
'<|system|>': 64794 -> -100
265+
'': 30910 -> -100
266+
'\n': 13 -> -100
267+
'Answer': 20115 -> -100
268+
'the': 267 -> -100
269+
'following': 1762 -> -100
270+
...
271+
'know': 683 -> -100
272+
'the': 267 -> -100
273+
'response': 3010 -> -100
274+
'details': 3296 -> -100
275+
'.': 30930 -> -100
276+
'<|assistant|>': 64796 -> -100
277+
'': 30910 -> 30910
278+
'\n': 13 -> 13
279+
'I': 307 -> 307
280+
'need': 720 -> 720
281+
'to': 289 -> 289
282+
'use': 792 -> 792
283+
...
284+
<< << << << << << < Sanity
285+
Check
286+
```
287+
288+
字样,每行依次表示一个 detokenized string, token_id 和 target_id。其中,`target_id``token_id`在模型词表中的索引,`-100`表示该
289+
token 不参与 `loss` 计算。
290+
291+
2. `_prepare_model_for_training` 的作用是遍历模型的所有可训练参数,并确保它们的数据类型为`torch.float32`
292+
这在某些情况下是必要的,因为混合精度训练或其他操作可能会更改模型参数的数据类型。该代码默打开,可以注释,但是如果使用
293+
`half` 格式训练出现问题,可以切换回这个代码,显存可能增加。
294+
3. 在我们的[Huggingface模型代码](https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py)中,有以下内容:
295+
```python
296+
if self.gradient_checkpointing and self.training:
297+
layer_ret = torch.utils.checkpoint.checkpoint(
298+
layer,
299+
hidden_states,
300+
attention_mask,
301+
rotary_pos_emb,
302+
kv_caches[index],
303+
use_cache,
304+
use_reentrant=False
305+
)
306+
```
307+
这可能导致训练的时候显存增加,因此,如果您的显存不足,可以尝试将``` use_reentrant``` 修改为`True`
308+
4. 微调后的模型可以使用任何支持 `peft` 载入的模型加速框架,在这里,我们没有提供demo。
309+
5. 本仓库的微调数据集格式与 API 微调数据集格式有一定区别
310+
+ ZhipuAI API 微调数据集中的 `messages` 字段在本仓库为 `conversation` 字段。
311+
+ ZhipuAI API 中的微调文件为 `jsonl`, 在本仓库,需要简单的将文件名改为 `json`
312+
313+
## 参考文献
314+
315+
```
316+
317+
@inproceedings{liu2022p,
318+
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
319+
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
320+
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short
321+
Papers)},
322+
pages={61--68},
323+
year={2022}
324+
}
325+
326+
@misc{tang2023toolalpaca,
327+
title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
328+
author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
329+
year={2023},
330+
eprint={2306.05301},
331+
archivePrefix={arXiv},
332+
primaryClass={cs.CL}
333+
}
334+
335+
```

0 commit comments

Comments
 (0)