Commit
Add DPO ref model
boy-hack committed Dec 7, 2023
1 parent 4edf222 commit cb7eb8b
Showing 1 changed file with 112 additions and 21 deletions.
133 changes: 112 additions & 21 deletions train_dpo.py
@@ -39,6 +39,8 @@
pre_lora_train_path = "" # to resume from a previous LoRA run, put that run's checkpoint path here
lora_rank = 8
lora_alpha = 32
# ref model
ref_device = "cuda:1"
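# NOTE: the reference model stays frozen during DPO, so it is placed on its own GPU to avoid
# competing with the policy model for memory; this assumes a second CUDA device is available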

global_pic = {
"step": [],
@@ -90,7 +92,7 @@ def prepare_ref_model():
trust_remote_code=True,
)
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(pre_train_path, trust_remote_code=True, device_map="cuda:2",
model = AutoModelForCausalLM.from_pretrained(pre_train_path, trust_remote_code=True, device_map=ref_device,
config=config)
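# the reference model is used only to score sequences: eval() disables dropout, and every
# forward pass through it in train() runs under torch.no_grad()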
model.eval()
return model
@@ -220,6 +222,89 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
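# The body of dpo_loss lies outside the shown hunks; assuming the standard DPO objective
# (Rafailov et al., 2023), the returned values correspond to:
#   pi_logratios     = policy_chosen_logps - policy_rejected_logps
#   ref_logratios    = reference_chosen_logps - reference_rejected_logps  (treated as 0 when reference_free=True)
#   losses           = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
#   chosen_rewards   = beta * (policy_chosen_logps - reference_chosen_logps).detach()
#   rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()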
return losses, chosen_rewards, rejected_rewards


def reformat_sft(instruction, input):
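"""Build an Alpaca-style prompt from an instruction and an optional input."""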
if input:
prefix = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n"
f"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)
else:
prefix = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n"
f"### Instruction:\n{instruction}\n\n### Response:"
)
return prefix


def llm_decode(model, question):
question = reformat_sft(question, "")

generation_kwargs = {
"top_p": 0.8,
"temperature": 0.1,
"max_new_tokens": 2048,
"no_repeat_ngram_size": 4,
"do_sample": True,
}

inputs = tokenizer.encode(question, return_tensors='pt', truncation=True)
inputs = inputs.cuda()
output = model.generate(input_ids=inputs, **generation_kwargs)[0]
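# decode and strip the prompt prefix; this assumes the decoded prompt text is identical to the original question string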
ret = tokenizer.decode(output, skip_special_tokens=True)[len(question):]
return ret


def batch_matrix():
questions = [
"SQL注入的基本原理",
"sqlmap -r 1.txt --dbs 是什么意思",
"详细描述一下CVE-2019-0708",
"蓝队做的再好,对于甲方赚钱来说有用吗?",
"你认为股票外汇等金融产品依靠交易系统可能长久稳定的盈利吗?",
'''这是一道CTF,请告诉我解题思路
<?php
$v1=0;$v2=0;$v3=0;
$a=(array)unserialize(@$_GET['foo']);
if(is_array($a)){
is_numeric(@$a["param1"])?die("nope"):NULL;
if(@$a["param1"]){
($a["param1"]>2017)?$v1=1:NULL;
}
if(is_array(@$a["param2"])){
if(count($a["param2"])!==5 OR !is_array($a["param2"][0])) die("nope");
$pos = array_search("nudt", $a["param2"]);
$pos===false?die("nope"):NULL;
foreach($a["param2"] as $key=>$val){
$val==="nudt"?die("nope"):NULL;
}
$v2=1;
}
}
$c=@$_GET['egg'];
$d=@$_GET['fish'];
if(@$c[1]){
if(!strcmp($c[1],$d) && $c[1]!==$d){
eregi("M|n|s",$d.$c[0])?die("nope"):NULL;
strpos(($c[0].$d), "MyAns")?$v3=1:NULL;
}
}
if($v1 && $v2 && $v3){
include "flag.php";
echo $flag;
}
?>
'''.strip()

]
print("生成测试")
for index, question in enumerate(questions):
print(f"{index}.{question}")
print(llm_decode(model_engine, question))


def train(model, reference_model, epoch):
global global_pic, global_step
data_engine = prepare_data()
@@ -235,15 +320,19 @@ def train(model, reference_model, epoch):
choose_labels_ids = item["choose_labels_ids"]
rejected_input_ids = item["rejected_input_ids"]
reject_labels_ids = item["reject_labels_ids"]
# with torch.no_grad():
# chosen_input_ids = chosen_input_ids.cuda(2)
# choose_labels_ids = choose_labels_ids.cuda(2)
# rejected_input_ids = rejected_input_ids.cuda(2)
# reject_labels_ids = reject_labels_ids.cuda(2)
# reference_chosen_logits = reference_model.forward(input_ids=chosen_input_ids,
# labels=choose_labels_ids).logits
# reference_rejected_logits = reference_model.forward(input_ids=rejected_input_ids,
# labels=reject_labels_ids).logits
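# score the chosen and rejected sequences with the frozen reference model on ref_device;
# no gradients are needed, and the resulting log-probs are moved back to the policy device below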
with torch.no_grad():
chosen_input_ids = chosen_input_ids.to(ref_device)
choose_labels_ids = choose_labels_ids.to(ref_device)
rejected_input_ids = rejected_input_ids.to(ref_device)
reject_labels_ids = reject_labels_ids.to(ref_device)
reference_chosen_logits = reference_model(input_ids=chosen_input_ids,
labels=choose_labels_ids).logits
reference_chosen_logps = _get_batch_logps(reference_chosen_logits, choose_labels_ids,
average_log_prob=False)
reference_rejected_logits = reference_model(input_ids=rejected_input_ids,
labels=reject_labels_ids).logits
reference_rejected_logps = _get_batch_logps(reference_rejected_logits, reject_labels_ids,
average_log_prob=False)

chosen_input_ids = chosen_input_ids.cuda()
choose_labels_ids = choose_labels_ids.cuda()
@@ -255,15 +344,15 @@ def train(model, reference_model, epoch):

policy_rejected_logits = model(input_ids=rejected_input_ids, labels=reject_labels_ids).logits.to(torch.float32)
policy_rejected_logps = _get_batch_logps(policy_rejected_logits, reject_labels_ids, average_log_prob=False)
# reference_chosen_logits = reference_chosen_logits.to(choose_labels_ids.device)
# reference_rejected_logits = reference_rejected_logits.to(reject_labels_ids.device)
# reference_chosen_logps = _get_batch_logps(reference_chosen_logits, choose_labels_ids, average_log_prob=False,
# tokenizer=tokenizer)
# reference_rejected_logps = _get_batch_logps(reference_rejected_logits, reject_labels_ids,
# average_log_prob=False, tokenizer=tokenizer)

reference_chosen_logps = reference_chosen_logps.cuda()
reference_rejected_logps = reference_rejected_logps.cuda()
loss, chosen_rewards, rejected_rewards = dpo_loss(
policy_chosen_logps, policy_rejected_logps, torch.FloatTensor(0), torch.FloatTensor(0), reference_free=True)
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps,
reference_free=False)
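# reward_accuracies: fraction of pairs where the chosen response gets the higher implicit reward;
# margins: mean gap between chosen and rejected rewards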
reward_accuracies = (chosen_rewards > rejected_rewards).float().cpu().mean()
margins = (chosen_rewards - rejected_rewards).cpu().mean()

show_loss = loss.mean().item()
running_loss += show_loss
epoch_loss += show_loss
@@ -285,18 +374,20 @@ def train(model, reference_model, epoch):
pbar.set_postfix({
"step": step,
"loss": show_loss,
"chosen_rewards": chosen_rewards.item(),
"rejected_rewards": rejected_rewards.item()
"reward_accuracies": reward_accuracies.item(),
"margins": margins.item()
})
pbar.update(1)
step += 1
global_step += 1

print(f"epoch:{epoch} loss:{epoch_loss / step}")
global_pic["step"].append(global_step)
global_pic["loss"].append(epoch_loss / step)
save_loss_pic()
pbar.close()
save_model(model_engine, f"{output_dir}/secgpt-base-epoch-{i + 1}")
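# quick qualitative check: generate answers for the fixed prompt set after saving each epoch's checkpoint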
batch_matrix()


if __name__ == "__main__":
@@ -307,9 +398,9 @@ def train(model, reference_model, epoch):
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
model_engine = prepare_model()
# ref_model = prepare_ref_model()
ref_model = prepare_ref_model()

optimizer = AdamW(model_engine.parameters(), lr=lr, correct_bias=True)

for i in range(num_train_epochs):
train(model_engine, None, i)
train(model_engine, ref_model, i)
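
Note: _get_batch_logps is called throughout this diff but is defined outside the shown hunks. For reference, a minimal sketch of how such a helper is commonly implemented for causal LMs, assuming labels follow the Hugging Face convention of -100 for ignored positions (the actual helper in this repo may differ), looks like:

import torch

def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor,
                     average_log_prob: bool = False) -> torch.FloatTensor:
    # logits: (batch, seq_len, vocab); labels: (batch, seq_len), -100 marks ignored positions
    labels = labels[:, 1:].clone()              # next-token targets
    logits = logits[:, :-1, :]                  # drop the logit after the last target
    loss_mask = labels != -100
    labels[labels == -100] = 0                  # any valid index; masked out below
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,
                                   index=labels.unsqueeze(2)).squeeze(2)
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)  # summed log-prob per sequence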
