Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add license 1 #5

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Tarek Naous

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def generate_response(text):
Generated example:

```python
input = "! الله يلعن هالبلد انقطعت الكهرباء"
input = "! انقطعت الكهرباء"
generate_response(input)

#Generated response
Expand Down
1 change: 1 addition & 0 deletions model/bert2bert-load-pretrained.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"bert2bert-load-pretrained.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["This notebook shows you how you can easily load the pre-trained BERT2BERT model for empathetic response generation in Arabic and use it to generate response on your input."],"metadata":{"id":"4BHdjrADiujc"}},{"cell_type":"code","metadata":{"id":"IIxGYlg4R_pf","colab":{"base_uri":"https://localhost:8080/"},"outputId":"42d4da69-802c-4296-c80b-c783608149ef"},"source":["#Install dependencies\n","!pip install transformers\n","!pip install pyarabic\n","!pip install farasapy\n","!git clone https://github.com/aub-mind/arabert"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Successfully installed huggingface-hub-0.2.1 pyyaml-6.0 sacremoses-0.0.46 tokenizers-0.10.3 transformers-4.15.0\n"]}]},{"cell_type":"code","metadata":{"id":"PhAhYVm_SCCz"},"source":["#Import transformers\n","import transformers"],"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#Load arabert preprocessor\n","from arabert.preprocess import ArabertPreprocessor\n","arabert_prep = ArabertPreprocessor(model_name=\"bert-base-arabert\", keep_emojis=False)"],"metadata":{"id":"yn_xdNIG-CR6"},"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"XrjbhIs_g50g","colab":{"base_uri":"https://localhost:8080/"},"outputId":"8187938e-57ef-42b0-e226-9047e87e1fa6"},"source":["#Load Model\n","from transformers import EncoderDecoderModel, AutoTokenizer\n","tokenizer = AutoTokenizer.from_pretrained(\"tareknaous/bert2bert-empathetic-response-msa\")\n","model = EncoderDecoderModel.from_pretrained(\"tareknaous/bert2bert-empathetic-response-msa\")\n","\n","model.to(\"cuda\")\n","model.eval()\n","print(\"done\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["done\n"]}]},{"cell_type":"code","source":["#Function to generater response/post-process\n","def generate_response(text):\n"," text_clean = arabert_prep.preprocess(text)\n"," inputs = tokenizer.encode_plus(text_clean,return_tensors='pt')\n"," outputs = model.generate(input_ids = inputs.input_ids.to(\"cuda\"),\n"," attention_mask = inputs.attention_mask.to(\"cuda\"),\n"," num_beams=1,\n"," do_sample = True,\n"," min_length=10,\n"," top_k = 50,\n"," temperature = 1,\n"," length_penalty =2)\n"," preds = tokenizer.batch_decode(outputs) \n"," response = str(preds)\n"," response = response.replace(\"\\'\", '')\n"," response = response.replace(\"[[CLS]\", '')\n"," response = response.replace(\"[SEP]]\", '')\n"," response = str(arabert_prep.desegment(response))\n"," return response"],"metadata":{"id":"X2uJIDi_96bX"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#Trial\n","text = \"الله يلعن هالبلد انقطعت الكهربا !\"\n","generate_response(text)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":35},"id":"TndxODmm-YHK","outputId":"38b5fd7e-a110-45e4-bd01-5ee6db08a654"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"},"text/plain":["' بالطبع ، لم يكن علي كالذهاب إلى المدرسة من قبل '"]},"metadata":{},"execution_count":12}]}]}
1 change: 1 addition & 0 deletions model/bert2bert-train.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"bert2bert-train.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["In this notebook, you are shown **how to train** a BERT2BERT model initialized with AraBERT pre-trained parameters on the Arabic empathetic message-response dataset. A gradio demo is also provided at the end."],"metadata":{"id":"NoMWZf5QikF0"}},{"cell_type":"code","metadata":{"id":"NZh4x-Matncr"},"source":["#Install dependencies\n","!pip install git-python==1.0.3\n","!pip install sacrebleu==1.4.2\n","!pip install rouge_score\n","!pip install farasapy\n","!git clone https://github.com/aub-mind/arabert\n","!pip install pyarabic\n","!pip install datasets\n","!pip install transformers==4.2\n","!git clone https://github.com/tareknaous/dialectal-conv/"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"TWYSoFH-tLQa"},"source":["#Fetch dataset\n","!wget https://raw.githubusercontent.com/aub-mind/Arabic-Empathetic-Chatbot/master/arabic-empathetic-conversations.csv"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xtIT7o7DtpV1"},"source":["import os\n","import numpy as np\n","import pandas as pd\n","from datasets import load_dataset \n","import transformers\n","from transformers import BertTokenizer, EncoderDecoderModel\n","from sacrebleu import corpus_bleu\n","from transformers import BertTokenizerFast, EncoderDecoderModel\n","from transformers import TrainingArguments\n","from dataclasses import dataclass, field\n","from typing import Optional"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"CV9zIXY6t6lX"},"source":["encoder_max_length=75\n","decoder_max_length=75\n","model_name = \"aubmindlab/bert-base-arabert\"\n","\n","tokenizer = BertTokenizerFast.from_pretrained(model_name)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2IjzpXipt7x4"},"source":["all_data = load_dataset(\"ArabicEmpatheticChatbot.py\")\n","train_data = all_data['train'].train_test_split(test_size=0.1,seed=42)['train']\n","val_data = all_data['train'].train_test_split(test_size=0.1,seed=42)['test']\n","dev_data = val_data.train_test_split(test_size=0.5,seed=42)['train']\n","test_data = val_data.train_test_split(test_size=0.5,seed=42)['test']"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5Ey-O3tet9VX"},"source":["print(\"Length of train data\",len(train_data))\n","print(\"Length of dev data\",len(dev_data))\n","print(\"Length of test data\",len(test_data))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"x-MyUbCwukal"},"source":["def process_data_to_model_inputs(batch): \n"," # Tokenizer will automatically set [BOS] <text> [EOS] \n"," inputs = tokenizer(batch[\"context\"], padding=\"max_length\", truncation=True, max_length=encoder_max_length)\n"," outputs = tokenizer(batch[\"response\"], padding=\"max_length\", truncation=True, max_length=decoder_max_length)\n"," \n"," batch[\"input_ids\"] = inputs.input_ids \n"," batch[\"attention_mask\"] = inputs.attention_mask \n"," batch[\"decoder_input_ids\"] = outputs.input_ids \n"," batch[\"labels\"] = outputs.input_ids.copy() \n"," # mask loss for padding \n"," batch[\"labels\"] = [ \n"," [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch[\"labels\"]\n"," ] \n"," batch[\"decoder_attention_mask\"] = outputs.attention_mask \n"," \n"," return batch"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"x2GbK_XHuea7"},"source":["batch_size=16"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"HDO04Sw6uOmk"},"source":["\n","train_data = train_data.map(\n"," process_data_to_model_inputs, \n"," batched=True, \n"," batch_size=batch_size, \n"," remove_columns=[\"context\", \"response\"],\n",")\n","train_data.set_format(\n"," type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"decoder_input_ids\", \"decoder_attention_mask\", \"labels\"],\n",")\n","\n","dev_data = dev_data.map(\n"," process_data_to_model_inputs, \n"," batched=True, \n"," batch_size=batch_size, \n"," remove_columns=[\"context\", \"response\"],\n",")\n","dev_data.set_format(\n"," type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"decoder_input_ids\", \"decoder_attention_mask\", \"labels\"],\n",")\n","\n","test_data = test_data.map(\n"," process_data_to_model_inputs, \n"," batched=True, \n"," batch_size=batch_size, \n"," remove_columns=[\"context\", \"response\"],\n",")\n","test_data.set_format(\n"," type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"decoder_input_ids\", \"decoder_attention_mask\", \"labels\"],\n",")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"cpSEGGY-u4D7"},"source":["from transformers import EncoderDecoderModel\n","\n","arabert2arabert = EncoderDecoderModel.from_encoder_decoder_pretrained(model_name, model_name, tie_encoder_decoder=False)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xVvr4Baeu6Z3"},"source":["#set special tokens\n","arabert2arabert.config.decoder_start_token_id = tokenizer.cls_token_id \n","arabert2arabert.config.eos_token_id = tokenizer.sep_token_id\n","arabert2arabert.config.pad_token_id = tokenizer.pad_token_id\n","\n","#sensible parameters for beam search\n","#set decoding params \n","arabert2arabert.config.max_length = 64\n","arabert2arabert.config.early_stopping = True\n","\n","arabert2arabert.config.num_beams = 1\n","arabert2arabert.config.vocab_size = arabert2arabert.config.encoder.vocab_size"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"OM3Rcr4Ky6T9"},"source":["@dataclass\n","class Seq2SeqTrainingArguments(TrainingArguments):\n"," label_smoothing: Optional[float] = field(\n"," default=0.0, metadata={\"help\": \"The label smoothing epsilon to apply (if not zero).\"}\n"," )\n"," sortish_sampler: bool = field(default=False, metadata={\"help\": \"Whether to SortishSamler or not.\"})\n"," predict_with_generate: bool = field(\n"," default=False, metadata={\"help\": \"Whether to use generate to calculate generative metrics (ROUGE, BLEU).\"}\n"," )\n"," adafactor: bool = field(default=False, metadata={\"help\": \"whether to use adafactor\"})\n"," encoder_layerdrop: Optional[float] = field(\n"," default=None, metadata={\"help\": \"Encoder layer dropout probability. Goes into model.config.\"}\n"," )\n"," decoder_layerdrop: Optional[float] = field(\n"," default=None, metadata={\"help\": \"Decoder layer dropout probability. Goes into model.config.\"}\n"," )\n"," dropout: Optional[float] = field(default=None, metadata={\"help\": \"Dropout probability. Goes into model.config.\"})\n"," attention_dropout: Optional[float] = field(\n"," default=None, metadata={\"help\": \"Attention dropout probability. Goes into model.config.\"}\n"," )\n"," lr_scheduler: Optional[str] = field(\n"," default=\"linear\", metadata={\"help\": f\"Which lr scheduler to use.\"}\n"," )"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kcRPh8jUu9P-"},"source":["import torch\n","import torch.nn as nn\n","\n","def compute_metrics(pred):\n"," labels_ids = pred.label_ids\n"," #pred_ids = torch.argmax(pred.predictions,dim=2)\n"," pred_ids = pred.predictions \n","\n"," # all unnecessary tokens are removed\n"," pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n"," labels_ids[labels_ids == -100] = tokenizer.pad_token_id\n"," label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n","\n"," return {\"bleu\": round(corpus_bleu(pred_str , [label_str]).score, 4)}"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5xhytujGu_IK"},"source":["#Set training arguments \n","training_args = Seq2SeqTrainingArguments(\n"," output_dir=\"./model\",\n"," per_device_train_batch_size=batch_size,\n"," per_device_eval_batch_size=batch_size//2,\n"," gradient_accumulation_steps = 2,\n"," predict_with_generate=True,\n"," do_eval=True,\n"," evaluation_strategy =\"epoch\",\n"," do_train=True,\n"," logging_steps=500, \n"," save_steps= 32965 // ( batch_size * 2), \n"," warmup_steps=100,\n"," eval_steps=10,\n"," #max_steps=16, # delete for full training\n"," num_train_epochs=5,# uncomment for full training\n"," overwrite_output_dir=True,\n"," save_total_limit=0,\n"," fp16=True, \n",")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ip4vf0JwvJed"},"source":["# instantiate trainer\n","trainer = Seq2SeqTrainer(\n"," model=arabert2arabert,\n"," args=training_args,\n"," compute_metrics=compute_metrics,\n"," train_dataset=train_data,\n"," eval_dataset=dev_data,\n",")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ymQ6xubTvKm5"},"source":["#Train\n","trainer.train()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"T18Ghyn2IYER"},"source":["#Evaluate\n","eval_output = trainer.evaluate()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wmpVC-XbIdD-"},"source":["#Compute perplexity\n","import math\n","perplexity = math.exp(eval_output[\"eval_loss\"])\n","print('\\nEvaluate Perplexity: {:10,.2f}'.format(perplexity))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"yv4pAbXTIePL"},"source":["#Save tokenizer and model\n","trainer._save(\"./arabert2arabert\")\n","tokenizer.save_pretrained(\"./arabert2arabert\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ggBQe7ZKJDBp"},"source":["**Gradio Demo** \\\\\n","This allows you to create a sharable web application of the model"]},{"cell_type":"code","source":["!pip install gradio\n","import gradio as gr"],"metadata":{"id":"P-Gr46vHh7Hk"},"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ux9WG4Q-JEns"},"source":["from transformers import EncoderDecoderModel, AutoTokenizer\n","from datasets import load_dataset \n","from arabert.preprocess import ArabertPreprocessor\n","from torch.utils.data.dataloader import DataLoader\n","from transformers import default_data_collator\n","from torch.utils.data.sampler import SequentialSampler\n","import torch\n","from tqdm.notebook import tqdm"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"t9ss8ggxJrzd"},"source":["model_name=\"bert-base-arabert\"\n","arabert_prep = ArabertPreprocessor(model_name=model_name, keep_emojis=False)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"TuHRn5CSJGa3"},"source":["tokenizer = AutoTokenizer.from_pretrained(\"./arabert2arabert\")\n","model = EncoderDecoderModel.from_pretrained(\"./arabert2arabert\")\n","\n","model.to(\"cuda\")\n","model.eval()\n","print(\"done\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"97v_PxuHI700"},"source":["def generate_response(text, minimum_length, k):\n"," text_clean = arabert_prep.preprocess(text)\n"," inputs = tokenizer.encode_plus(text_clean,return_tensors='pt')\n"," outputs = model.generate(input_ids = inputs.input_ids.to(\"cuda\"),\n"," attention_mask = inputs.attention_mask.to(\"cuda\"),\n"," num_beams=1,\n"," do_sample = True,\n"," min_length=minimum_length,\n"," top_k = k,\n"," temperature = 1,\n"," length_penalty =2)\n"," preds = tokenizer.batch_decode(outputs) \n"," response = str(preds)\n"," response = response.replace(\"\\'\", '')\n"," response = response.replace(\"[[CLS]\", '')\n"," response = response.replace(\"[SEP]]\", '')\n"," response = str(arabert_prep.desegment(response))\n"," return response"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"zBVySYgFJSQt"},"source":["gr.Interface(fn=generate_response,\n"," inputs=[\n"," gr.inputs.Textbox(),\n"," gr.inputs.Slider(5, 20, step=1, label='Minimum Output Length'),\n"," gr.inputs.Slider(10, 1000, step=10, label='Top-K'),\n"," ],\n"," outputs=\"text\").launch(share=True)"],"execution_count":null,"outputs":[]}]}