|
223 | 223 | "metadata": {},
|
224 | 224 | "outputs": [],
|
225 | 225 | "source": [
|
226 |
| - "# Define template: `response_template_ids`, `collator`, `format_prompt`\n", |
227 |
| - "response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[1:]\n", |
228 |
| - "collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)\n", |
229 |
| - "\n", |
230 | 226 | "def format_prompt(example) -> Tuple[str, str]:\n",
|
231 | 227 | " return f\"{example['context']} \\n-- Question: {example['question']}{response_template}\", example['answer']"
|
232 | 228 | ]
|
|
323 | 319 | " return out\n",
|
324 | 320 | "\n",
|
325 | 321 | " dataset.set_transform(batched_mutate)\n",
|
326 |
| - " return dataset # dataset.map(batched_mutate, batched=True, num_proc=num_proc, remove_columns=original_columns)" |
| 322 | + " return dataset" |
327 | 323 | ]
|
328 | 324 | },
|
329 | 325 | {
|
|
427 | 423 | " model_ref, # The model with peft adapters turned off will be used as a reference model if not provided\n",
|
428 | 424 | " tokenizer=tokenizer,\n",
|
429 | 425 | " train_dataset=ds,\n",
|
430 |
| - " # eval_dataset=eval_dataset,\n", |
431 |
| - " beta=0.2, # TODO: determine\n", |
| 426 | + " beta=0.2, \n", |
432 | 427 | " max_length=2048,\n",
|
433 |
| - " # max_target_length=248,\n", |
434 | 428 | " max_prompt_length=1500,\n",
|
435 | 429 | " args=TrainingArguments(\n",
|
436 | 430 | " output_dir=\"./dpo_results\",\n",
|
437 | 431 | " optim=\"paged_adamw_32bit\",\n",
|
438 | 432 | "\n",
|
439 | 433 | " max_grad_norm=0.3,\n",
|
440 | 434 | " warmup_ratio=0.03,\n",
|
441 |
| - " # group_by_length=True,\n", |
442 | 435 | " \n",
|
443 | 436 | " learning_rate=2e-4,\n",
|
444 | 437 | " weight_decay=0.001,\n",
|
|
533 | 526 | "metadata": {},
|
534 | 527 | "outputs": [],
|
535 | 528 | "source": [
|
536 |
| - "# Fine-tuned model\n", |
| 529 | + "# Save trained model\n", |
537 | 530 | "new_model = \"open_llama_3b_v2_sft_plus_dpo\"\n",
|
538 | 531 | "trainer.model.save_pretrained(new_model)"
|
539 | 532 | ]
|
|
0 commit comments