Add Mistral conversation training
boy-hack committed Apr 15, 2024
1 parent c81a508 commit 350f689
Showing 3 changed files with 101 additions and 3 deletions.
91 changes: 91 additions & 0 deletions dataset/mistral.py
@@ -0,0 +1,91 @@
from typing import List
import torch
from transformers import AutoTokenizer
import json
import random
import numpy as np


# <s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>
def make_context(context: List, tokenizer: AutoTokenizer):
    start_token = "<s>"
    stop_token = "<|end_of_turn|>"

    def _tokenizer(text):
        return tokenizer.encode(text, add_special_tokens=False)

    prompt = start_token
    stop_token_ids = _tokenizer(stop_token)
    start_token_ids = _tokenizer(start_token)

    input_ids = start_token_ids.copy()
    labels = start_token_ids.copy()

    for i in range(len(context)):
        if i % 2 == 0:
            # User turn (even index): mask these tokens with -100 so the loss ignores them.
            pp = f"GPT4 Correct User: {context[i]['content']}"
            prompt += pp
            input_ids.extend(_tokenizer(pp))
            length = len(_tokenizer(pp))
            labels.extend(length * [-100])
        else:
            # Assistant turn (odd index): keep the tokens as labels so they are trained on.
            pp = f"GPT4 Correct Assistant: {context[i]['content']}"
            prompt += pp
            input_ids.extend(_tokenizer(pp))
            labels.extend(_tokenizer(pp))

        # Close every turn with <|end_of_turn|>, matching the template shown above.
        prompt += stop_token
        input_ids.extend(stop_token_ids)
        labels.extend(stop_token_ids)
    return prompt, input_ids, labels


class DataEngine():
    def __init__(self, tokenizer, micro_batch_size, max_length, checkpoint_step=0, data_path=""):
        self.micro_batch_size = micro_batch_size
        self.max_length = max_length
        with open(data_path, encoding="utf-8") as f:
            self.train_dataset = json.load(f)
        random.shuffle(self.train_dataset)
        self.tokenizer = tokenizer
        self.index = checkpoint_step
        self.data = []
        for item in self.train_dataset:
            _, input_ids, labels = make_context(
                item,
                tokenizer
            )
            self.data.append({
                "input_ids": input_ids,
                "labels": labels,
            })

    def get_data(self):
        for item in self.data:
            input_ids = item["input_ids"]
            labels = item["labels"]
            # Note: the reshape assumes each sample is exactly max_length tokens;
            # shorter or longer samples would need padding or truncation first.
            input_ids = torch.LongTensor(np.asarray(input_ids).reshape(1, self.max_length))
            labels = torch.LongTensor(np.asarray(labels).reshape(1, self.max_length))
            yield dict(input_ids=input_ids, labels=labels)

    def __len__(self):
        # Only train on the first N entries; here N covers the whole prepared dataset.
        return len(self.data)


if __name__ == '__main__':
    chat = [
        {
            "content": "你好"
        },
        {
            "content": "11"
        },
        {
            "content": "你好22"
        },
        {
            "content": "222"
        },
    ]
    print(make_context(chat, AutoTokenizer.from_pretrained("FuseAI/FuseChat-7B-VaRM")))
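
A quick sanity check of the masking behaviour (a sketch, not part of the commit; the import path assumes the repository layout shown above, and the tokenizer mirrors the __main__ block):

# Minimal sketch: user turns should be masked with -100, while assistant turns
# and the <|end_of_turn|> delimiters remain trainable.
from transformers import AutoTokenizer

from dataset.mistral import make_context  # import path assumed from this commit

tokenizer = AutoTokenizer.from_pretrained("FuseAI/FuseChat-7B-VaRM")
prompt, input_ids, labels = make_context(
    [{"content": "Hello"}, {"content": "Hi there!"}],
    tokenizer,
)
assert len(input_ids) == len(labels)
print(prompt)
print(sum(t == -100 for t in labels), "of", len(labels), "tokens excluded from the loss")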
9 changes: 7 additions & 2 deletions train.py
@@ -6,9 +6,9 @@
import transformers
from peft import PeftModel, LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
-from transformers import AutoTokenizer, AdamW, AutoModelForCausalLM
+from transformers import AutoTokenizer, AdamW, AutoModelForCausalLM, get_linear_schedule_with_warmup

-from dataset import pretrain, sft, chatml
+from dataset import pretrain, sft, chatml, mistral

global_pic = {
"step": [],
@@ -47,6 +47,10 @@ def prepare_data():
    elif train_option == "chatml":
        data_engine = chatml.DataEngine(tokenizer, batch_size, max_position_embeddings,
                                        data_path=dataset_path)
+   elif train_option == "mistral":
+       data_engine = mistral.DataEngine(tokenizer, batch_size, max_position_embeddings,
+                                        data_path=dataset_path)
+
    else:
        raise ValueError("train_option must be one of pretrain, sft, pretrain_cache")
    return data_engine
@@ -185,5 +189,6 @@ def train(model, epoch):
model_engine = prepare_model()
lr = config["learning_rate"]
optimizer = AdamW(model_engine.parameters(), lr=lr, correct_bias=True)
+# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warm_up_ratio * total_steps, num_training_steps = total_steps)
for i in range(int(config["num_train_epochs"])):
    train(model_engine, i)
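
The scheduler import is added but the call itself stays commented out. A minimal sketch of how it could be wired in, assuming total_steps is derived from the data engine and warm_up_ratio is a tuning choice (both names are hypothetical, echoing the commented line):

# Hypothetical wiring for the commented-out scheduler; warm_up_ratio and the
# total_steps derivation are assumptions, not values defined in this commit.
total_steps = len(data_engine) * int(config["num_train_epochs"])
warm_up_ratio = 0.05
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(warm_up_ratio * total_steps),
    num_training_steps=total_steps,
)
# train() would then call scheduler.step() after each optimizer.step().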
4 changes: 3 additions & 1 deletion webdemo/webdemo.py
@@ -129,6 +129,8 @@ def evaluate(
        do_sample=True,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
+       eos_token_id=tokenizer.eos_token_id,
+       pad_token_id=tokenizer.eos_token_id
    )
    c = Thread(target=lambda: model.generate(input_ids=input_ids, **generation_config))
    c.start()
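
An aside on the two new keyword arguments (not from the diff): Mistral-family tokenizers typically ship without a pad token, so model.generate() warns and picks a fallback unless pad_token_id is supplied; aliasing it to the EOS token is the common workaround, and passing eos_token_id explicitly lets the streamed generation stop cleanly at end-of-sequence.

# Illustrative only: many Mistral-family checkpoints define no pad token,
# so generation code commonly aliases it to the EOS token up front.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id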
@@ -225,4 +227,4 @@ def inner(context, answer1, answer2, answer3, fankui):
parser.add_argument("--lora", type=str, help="lora模型")
parser.add_argument("--share_gradio", type=bool, default=False, help="开放外网访问")
args = parser.parse_args()
main(args.base_model, args.lora, args.share_gradio)
main(args.base_model, args.lora, args.share_gradio)
