Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nelson-lojo committed Dec 14, 2023
1 parent 33dbe6d commit c30c5b7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 89 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
dpo_results*/
sft_results*/
ppo_results*/
open_llama_3b_v2_*/
.Trash-1000/
.ipynb_checkpoints/
__pycache__/
122 changes: 36 additions & 86 deletions dpo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
"text": [
"/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.25.2\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n",
"2023-12-13 03:17:18.778257: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2023-12-13 03:17:18.828290: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"2023-12-14 04:28:31.767500: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2023-12-14 04:28:31.815854: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX512F AVX512_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"/home/ubuntu/.local/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
" warnings.warn(\n"
Expand All @@ -34,8 +34,7 @@
" TrainingArguments,\n",
")\n",
"from trl import (\n",
" DPOTrainer, \n",
" DataCollatorForCompletionOnlyLM\n",
" DPOTrainer\n",
")\n",
"\n",
"SFT_ADAPTER_DIRECTORY = \"./open_llama_3b_v2_sft/\""
Expand All @@ -44,9 +43,7 @@
{
"cell_type": "markdown",
"id": "7c6e0696-a01a-47ad-8df9-615ef6b14a0a",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"metadata": {},
"source": [
"## Model prep"
]
Expand Down Expand Up @@ -80,7 +77,7 @@
"# Load LLaMA tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\""
"tokenizer.padding_side = \"left\""
]
},
{
Expand Down Expand Up @@ -170,55 +167,6 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "21f425ef-6e0a-4061-a829-0482a292fb17",
"metadata": {},
"outputs": [],
"source": [
"def print_tokens_with_ids(txt):\n",
" tokens = tokenizer.tokenize(txt, add_special_tokens=False)\n",
" token_ids = tokenizer.encode(txt, add_special_tokens=False)\n",
" print(list(zip(tokens, token_ids)))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9201b1a0-7ebc-4486-aa73-35096b3de7bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('▁', 29500), ('<0x0A>', 13), ('--', 559), ('▁Answer', 13910), (':', 29537), ('<0x0A>', 13)]\n"
]
}
],
"source": [
"print_tokens_with_ids(response_template)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5ed99aab-ae87-4043-93f6-5c1e1bd82df8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('▁--', 1472), ('▁Question', 10706), (':', 29537), ('▁HI', 27003), ('<0x0A>', 13), ('--', 559), ('▁Answer', 13910), (':', 29537), ('<0x0A>', 13)]\n"
]
}
],
"source": [
"print_tokens_with_ids(f\"-- Question: HI{response_template}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "57c7b1d3-d997-485c-86c4-5f9653c87edd",
"metadata": {},
"outputs": [],
Expand All @@ -229,13 +177,9 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"id": "b6c42ff0-bc6f-45b8-a1b8-c1fabf9f2529",
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"metadata": {},
"outputs": [],
"source": [
"def mutate(response, num_tokens=1):\n",
Expand All @@ -260,17 +204,17 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 11,
"id": "9500a0da-9674-4035-b4c5-3a356c8bfc2f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'-- Hello, this is a long sentence RodSignature string mutation'"
"'-- Hello, this눉 ausive sentence to demonstrate string mutation'"
]
},
"execution_count": 14,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -281,7 +225,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"id": "f41a9378-8b13-4782-9db7-7e71127cf5ca",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -319,12 +263,12 @@
" return out\n",
"\n",
" dataset.set_transform(batched_mutate)\n",
" return dataset"
" return dataset # dataset.map(batched_mutate, batched=True, num_proc=num_proc, remove_columns=original_columns)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 13,
"id": "38925ecf-c61e-4eae-83ca-60ca3653b36e",
"metadata": {},
"outputs": [
Expand All @@ -337,7 +281,7 @@
"})"
]
},
"execution_count": 16,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -349,7 +293,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 14,
"id": "5aa180fa-b4b5-4f9c-93b1-440f4def1486",
"metadata": {},
"outputs": [
Expand All @@ -358,10 +302,10 @@
"text/plain": [
"{'prompt': 'CREATE TABLE head (age INTEGER) \\n-- Question: How many heads of the departments are older than 56 ?\\n-- Answer:\\n',\n",
" 'chosen': 'SELECT count(*) FROM head WHERE age > 56',\n",
" 'rejected': 'SELECT count offset Ker Dolத WHERE age > experience56'}"
" 'rejected': \"SELECT countK) '@ head wip Dav++)> 56\"}"
]
},
"execution_count": 17,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -372,7 +316,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 15,
"id": "b6ca7662-ea7a-4a85-91cd-46b3074f7560",
"metadata": {},
"outputs": [
Expand All @@ -381,10 +325,10 @@
"text/plain": [
"{'prompt': 'CREATE TABLE head (age INTEGER) \\n-- Question: How many heads of the departments are older than 56 ?\\n-- Answer:\\n',\n",
" 'chosen': 'SELECT count(*) FROM head WHERE age > 56',\n",
" 'rejected': 'SELECT count(*世 FROMicip AP ageading> 5FAULT'}"
" 'rejected': 'SELECT counterend) Joined head calendar age > ピ foster'}"
]
},
"execution_count": 18,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -403,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 16,
"id": "6dcd7ac6-3507-4605-b7fb-a77ff0141c9d",
"metadata": {},
"outputs": [
Expand All @@ -418,24 +362,30 @@
],
"source": [
"# Initialize Trainer\n",
"LORA_RANK = 4\n",
"assert LORA_RANK % 4 == 0, \"Please use a LoRA Rank divisible by 4\"\n",
"\n",
"trainer = DPOTrainer(\n",
" model,\n",
" model_ref, # The model with peft adapters turned off will be used as a reference model if not provided\n",
" tokenizer=tokenizer,\n",
" train_dataset=ds,\n",
" beta=0.2, \n",
" # eval_dataset=eval_dataset,\n",
" beta=0.2, # TODO: determine\n",
" max_length=2048,\n",
" # max_target_length=248,\n",
" max_prompt_length=1500,\n",
" args=TrainingArguments(\n",
" output_dir=\"./dpo_results\",\n",
" output_dir=f\"./dpo_results_r{LORA_RANK}\",\n",
" optim=\"paged_adamw_32bit\",\n",
"\n",
" max_grad_norm=0.3,\n",
" warmup_ratio=0.03,\n",
" # group_by_length=True,\n",
" \n",
" learning_rate=2e-4,\n",
" weight_decay=0.001,\n",
" num_train_epochs=1,\n",
" num_train_epochs=3,\n",
" max_steps=-1,\n",
" per_device_train_batch_size=2,\n",
" \n",
Expand All @@ -452,9 +402,9 @@
" report_to=\"tensorboard\"\n",
" ),\n",
" peft_config=LoraConfig(\n",
" lora_alpha=16,\n",
" lora_alpha=LORA_RANK // 4,\n",
" lora_dropout=0.1,\n",
" r=64,\n",
" r=LORA_RANK,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" )\n",
Expand Down Expand Up @@ -488,8 +438,8 @@
"\n",
" <div>\n",
" \n",
" <progress value='51' max='1981' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 51/1981 00:47 < 31:29, 1.02 it/s, Epoch 0.03/1]\n",
" <progress value='60' max='5943' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 60/5943 00:56 < 1:35:35, 1.03 it/s, Epoch 0.03/3]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
Expand Down Expand Up @@ -526,8 +476,8 @@
"metadata": {},
"outputs": [],
"source": [
"# Save trained model\n",
"new_model = \"open_llama_3b_v2_sft_plus_dpo\"\n",
"# Fine-tuned model\n",
"new_model = f\"open_llama_3b_v2_sft_plus_dpo_r{LORA_RANK}\"\n",
"trainer.model.save_pretrained(new_model)"
]
},
Expand Down
6 changes: 3 additions & 3 deletions reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ def get_similarity(candidate, reference) -> float:

return similarity

def get_row_diff(candidate_rows, solution_rows, eps=1e-5) -> float:
    """Score how closely two collections of rows agree.

    Both inputs are collapsed to sets, so ordering and duplicate rows are
    ignored and every row must be hashable.

    Args:
        candidate_rows: Iterable of rows being evaluated.
        solution_rows: Iterable of reference rows to compare against.
        eps: Small constant added to each denominator so that an empty
            row set yields a proportion of 0 instead of raising
            ZeroDivisionError.

    Returns:
        The mean of ``1 - excess_proportion`` and ``1 - missing_proportion``.
        Identical sets score exactly 1.0; fully disjoint non-empty sets
        score marginally above 0.0 because ``eps`` enlarges the divisors.
    """
    candidate_set = set(candidate_rows)
    solution_set = set(solution_rows)

    # Rows the candidate has that the solution does not, as a share of
    # the candidate's distinct rows.
    num_excess_rows = len(candidate_set - solution_set)
    excess_proportion = num_excess_rows / (len(candidate_set) + eps)

    # Rows the solution has that the candidate lacks, as a share of the
    # solution's distinct rows.
    num_missing_rows = len(solution_set - candidate_set)
    missing_proportion = num_missing_rows / (len(solution_set) + eps)

    return (1 - excess_proportion + 1 - missing_proportion) / 2

def get_reward(db_name: str, candidate_query: str, solution_query: str):
Expand Down

0 comments on commit c30c5b7

Please sign in to comment.