From c56fb36c092883bf07bfd8da6d863ab847186f70 Mon Sep 17 00:00:00 2001 From: david Date: Wed, 3 Nov 2021 11:04:04 +0300 Subject: [PATCH] add notebooks with the mined paraphrases --- .../finetune_t5_on_mined.ipynb | 2264 +++++++++++++++++ .../paranmt_mining.ipynb | 1153 +++++++++ 2 files changed, 3417 insertions(+) create mode 100644 emnlp2021/style_transfer/mining_parallel_corpus/finetune_t5_on_mined.ipynb create mode 100644 emnlp2021/style_transfer/mining_parallel_corpus/paranmt_mining.ipynb diff --git a/emnlp2021/style_transfer/mining_parallel_corpus/finetune_t5_on_mined.ipynb b/emnlp2021/style_transfer/mining_parallel_corpus/finetune_t5_on_mined.ipynb new file mode 100644 index 0000000..f94bce7 --- /dev/null +++ b/emnlp2021/style_transfer/mining_parallel_corpus/finetune_t5_on_mined.ipynb @@ -0,0 +1,2264 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train a paraphraser on the mined data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(223823, 7)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Unnamed: 0referencetranslationsimilaritylenght_diffref_toxtrn_tox
191205191205She must be out there somewhere.he must be out there, damn it.0.6298550.0606060.0000610.995908
197885197885\"Nuts!\"bullshit!0.7438360.3636360.0283490.999605
143495143495You are only about one half loser, the other h...from the second half, you're the winner.0.7784990.2931030.9329300.000052
112935112935Kick some NSC booty.kick some NSC asses.0.9016880.0000000.3590500.998997
5273252732Because it's where they used to keep prisoners...because they held prisoners there before they ...0.8378650.1625000.9536260.074862
9510495104He wants us to blow him away!he wants us to kill him!0.7666680.1666670.0006620.999194
200609200609'You don't bash a man's brains out when he's t...\"you can't break the head of a man tied up in ...0.7570320.1111110.8937700.006597
8532185321\"Here's the bats, if you got the balls\".\"take them, if you have the balls!\"0.7862420.1219510.0003850.998311
7331373313I don't want to fucking hear it.I don't want to hear anything.0.8288450.0606060.9800570.000058
146935146935One, you can be a waitress, or you can be a ca...one that you're gonna be a waitress...... or y...0.7089990.2142860.0001420.999406
\n", + "
" + ], + "text/plain": [ + " Unnamed: 0 reference \\\n", + "191205 191205 She must be out there somewhere. \n", + "197885 197885 \"Nuts! \n", + "143495 143495 You are only about one half loser, the other h... \n", + "112935 112935 Kick some NSC booty. \n", + "52732 52732 Because it's where they used to keep prisoners... \n", + "95104 95104 He wants us to blow him away! \n", + "200609 200609 'You don't bash a man's brains out when he's t... \n", + "85321 85321 \"Here's the bats, if you got the balls\". \n", + "73313 73313 I don't want to fucking hear it. \n", + "146935 146935 One, you can be a waitress, or you can be a ca... \n", + "\n", + " translation similarity \\\n", + "191205 he must be out there, damn it. 0.629855 \n", + "197885 \"bullshit! 0.743836 \n", + "143495 from the second half, you're the winner. 0.778499 \n", + "112935 kick some NSC asses. 0.901688 \n", + "52732 because they held prisoners there before they ... 0.837865 \n", + "95104 he wants us to kill him! 0.766668 \n", + "200609 \"you can't break the head of a man tied up in ... 0.757032 \n", + "85321 \"take them, if you have the balls!\" 0.786242 \n", + "73313 I don't want to hear anything. 0.828845 \n", + "146935 one that you're gonna be a waitress...... or y... 0.708999 \n", + "\n", + " lenght_diff ref_tox trn_tox \n", + "191205 0.060606 0.000061 0.995908 \n", + "197885 0.363636 0.028349 0.999605 \n", + "143495 0.293103 0.932930 0.000052 \n", + "112935 0.000000 0.359050 0.998997 \n", + "52732 0.162500 0.953626 0.074862 \n", + "95104 0.166667 0.000662 0.999194 \n", + "200609 0.111111 0.893770 0.006597 \n", + "85321 0.121951 0.000385 0.998311 \n", + "73313 0.060606 0.980057 0.000058 \n", + "146935 0.214286 0.000142 0.999406 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('filtered.tsv', sep='\\t', encoding='utf-8')\n", + "print(df.shape)\n", + "df.sample(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5521639867216506" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(df.ref_tox > df.trn_tox).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "xx = []\n", + "yy = []\n", + "for i, row in df.iterrows():\n", + " if row.ref_tox > row.trn_tox:\n", + " xx.append(row.reference)\n", + " yy.append(row.translation)\n", + " else:\n", + " yy.append(row.reference)\n", + " xx.append(row.translation)\n", + " \n", + "xydf = pd.DataFrame({'source': xx, 'target': yy})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Prepare datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import (\n", + " AdamW,\n", + " T5ForConditionalGeneration,\n", + " T5Tokenizer, T5TokenizerFast,\n", + " get_linear_schedule_with_warmup\n", + ")\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"ceshine/t5-paraphrase-paws-msrp-opinosis\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = T5TokenizerFast.from_pretrained(model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import 
train_test_split" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "223523 300\n" + ] + } + ], + "source": [ + "df_train, df_test = train_test_split(xydf, test_size=300)\n", + "print(df_train.shape[0], df_test.shape[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 56.1 s, sys: 1.78 s, total: 57.9 s\n", + "Wall time: 8.06 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "x1 = tokenizer(df_train.source.tolist(), truncation=True)\n", + "y1 = tokenizer(df_train.target.tolist(), truncation=True)\n", + "x2 = tokenizer(df_test.source.tolist(), truncation=True)\n", + "y2 = tokenizer(df_test.target.tolist(), truncation=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(223523, 300)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class PairsDataset(torch.utils.data.Dataset):\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " def __getitem__(self, idx):\n", + " assert idx < len(self.x['input_ids'])\n", + " item = {key: val[idx] for key, val in self.x.items()}\n", + " item['decoder_attention_mask'] = self.y['attention_mask'][idx]\n", + " item['labels'] = self.y['input_ids'][idx]\n", + " return item\n", + " \n", + " @property\n", + " def n(self):\n", + " return len(self.x['input_ids'])\n", + "\n", + " def __len__(self):\n", + " return self.n # * 2\n", + " \n", + "train_dataset = PairsDataset(x1, y1)\n", + "test_dataset = PairsDataset(x2, y2)\n", + "len(train_dataset), len(test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset, DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataloader = DataLoader(train_dataset, batch_size=4, drop_last=True, shuffle=True, num_workers=1)\n", + "test_dataloader = DataLoader(test_dataset, batch_size=4, drop_last=True, shuffle=True, num_workers=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine tune t5" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import (\n", + " AdamW,\n", + " T5ForConditionalGeneration,\n", + " T5Tokenizer,\n", + " get_linear_schedule_with_warmup\n", + ")\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint_name = 'SkolkovoInstitute/t5-paraphrase-paws-msrp-opinosis-paranmt'" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "model = T5ForConditionalGeneration.from_pretrained(checkpoint_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda:0')\n", + "model.to(device);" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import Trainer, TrainingArguments\n", + 
"from transformers.file_utils import cached_property\n", + "from typing import Tuple\n", + "\n", + "class TrAr(TrainingArguments):\n", + " @cached_property\n", + " def _setup_devices(self):\n", + " return device" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Dict, Union\n", + "\n", + "class DataCollatorWithPadding:\n", + " def __init__(self, tokenizer):\n", + " self.tokenizer = tokenizer\n", + "\n", + " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", + " batch = self.tokenizer.pad(\n", + " features,\n", + " padding=True,\n", + " )\n", + " ybatch = self.tokenizer.pad(\n", + " {'input_ids': batch['labels'], 'attention_mask': batch['decoder_attention_mask']},\n", + " padding=True,\n", + " ) \n", + " batch['labels'] = ybatch['input_ids']\n", + " batch['decoder_attention_mask'] = ybatch['attention_mask']\n", + " \n", + " return {k: torch.tensor(v) for k, v in batch.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "save_name = 'models/t5-cechine-nmt-mined-detox'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "todo: maybe, batch > 4 would do as well" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "training_args = TrAr(\n", + " output_dir=save_name, # output directory\n", + " overwrite_output_dir=True,\n", + " num_train_epochs=3, # total # of training epochs\n", + " per_device_train_batch_size=4, # batch size per device during training\n", + " gradient_accumulation_steps=4,\n", + " per_device_eval_batch_size=8, # batch size for evaluation\n", + " warmup_steps=300, # number of warmup steps for learning rate scheduler\n", + " weight_decay=0, # strength of weight decay\n", + " learning_rate=3e-5,\n", + " logging_dir='./logs', # directory for storing logs\n", + " logging_steps=100,\n", + " eval_steps=100,\n", + " evaluation_strategy='steps',\n", + " save_total_limit=1,\n", + " save_steps=5000,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model, # the instantiated 🤗 Transformers model to be trained\n", + " args=training_args, # training arguments, defined above\n", + " train_dataset=train_dataset, # training dataset\n", + " eval_dataset=test_dataset, # evaluation dataset\n", + " data_collator=data_collator,\n", + " tokenizer=tokenizer,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "import gc\n", + "gc.collect()\n", + "torch.cuda.empty_cache();" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " \n", + " [21330/41910 2:15:15 < 2:10:30, 2.63 it/s, Epoch 1.53/3]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation LossRuntimeSamples Per Second
1001.0846000.7934641.229200244.053000
2001.0902000.7881721.226800244.531000
3001.1094000.7853061.238100242.309000
4001.0773000.7846831.217100246.494000
5001.0891000.7830291.259200238.245000
6001.0772000.7814471.260200238.060000
7001.0769000.7796601.212500247.413000
8001.0579000.7788331.222500245.389000
9001.0781000.7788081.223800245.132000
10001.0673000.7783351.223300245.228000
11001.0510000.7782531.245100240.952000
12001.0429000.7757671.266100236.957000
13001.0700000.7757421.251100239.797000
14001.0749000.7749521.269100236.395000
15001.0539000.7753811.251400239.725000
16001.0815000.7717251.213400247.232000
17001.0809000.7716651.220500245.809000
18001.0755000.7713671.230700243.770000
19001.0754000.7706591.214500247.022000
20001.0396000.7709111.227600244.382000
21001.0433000.7711291.211900247.553000
22001.0429000.7690761.247500240.490000
23001.0620000.7691381.266100236.940000
24001.0857000.7685021.237500242.424000
25001.0536000.7674391.266000236.970000
26001.0297000.7674921.242600241.429000
27001.0492000.7663901.209600248.013000
28001.0466000.7667801.218400246.229000
29001.0553000.7670011.282000234.001000
30001.0735000.7647601.257000238.671000
31001.0637000.7654471.254900239.054000
32001.0532000.7644911.255600238.930000
33001.0252000.7643751.249100240.169000
34001.0837000.7641051.247100240.566000
35001.0301000.7636781.255300238.990000
36001.0651000.7632301.243700241.221000
37001.0643000.7620371.205700248.814000
38001.0320000.7625141.218300246.241000
39001.0373000.7615141.208400248.254000
40001.0552000.7613351.217000246.503000
41001.0640000.7616901.207100248.526000
42001.0663000.7620061.218900246.125000
43001.0611000.7607121.214600247.001000
44001.0584000.7613471.208400248.260000
45001.0485000.7605921.213300247.252000
46001.0352000.7596091.215600246.783000
47001.0609000.7591731.218700246.171000
48001.0520000.7590941.257100238.635000
49001.0526000.7593951.232700243.361000
50001.0363000.7594481.252600239.511000
51001.0821000.7595201.250100239.990000
52001.0252000.7583381.208200248.307000
53001.0769000.7570791.211300247.668000
54001.0710000.7568921.216700246.574000
55001.0425000.7565891.208200248.309000
56001.0788000.7563071.246000240.778000
57001.0397000.7569381.216500246.619000
58001.0374000.7561661.211700247.588000
59001.0784000.7561191.282300233.960000
60001.0606000.7556571.263600237.423000
61001.0378000.7559871.231000243.706000
62001.0482000.7562911.255300238.994000
63001.0531000.7553301.275900235.130000
64001.0298000.7555061.272900235.685000
65001.0572000.7551731.268600236.478000
66001.0624000.7547541.238500242.228000
67001.0386000.7553781.238600242.204000
68001.0481000.7545591.237700242.392000
69001.0208000.7532331.266900236.799000
70001.0105000.7529191.259100238.258000
71001.0183000.7537681.230400243.815000
72001.0551000.7529361.228700244.156000
73001.0318000.7538591.227300244.444000
74001.0239000.7534721.231600243.594000
75001.0488000.7531281.232500243.413000
76001.0509000.7531911.235600242.798000
77001.0463000.7529351.231200243.658000
78001.0553000.7525111.234200243.071000
79001.0441000.7534031.239000242.138000
80001.0409000.7528651.235300242.858000
81001.0385000.7522171.237500242.424000
82001.0119000.7524051.238100242.300000
83001.0334000.7522531.279700234.428000
84001.0346000.7527041.244700241.014000
85001.0256000.7527431.241000241.735000
86001.0300000.7528711.243200241.309000
87001.0250000.7519741.248200240.352000
88001.0264000.7509231.242600241.424000
89001.0359000.7514701.241100241.724000
90001.0739000.7508551.242900241.375000
91001.0578000.7503241.243500241.258000
92001.0714000.7495221.242900241.378000
93001.0087000.7497871.243900241.183000
94001.0190000.7500121.246600240.653000
95001.0265000.7498111.272300235.790000
96001.0123000.7499151.254800239.078000
97001.0507000.7489711.286300233.222000
98001.0281000.7489311.266300236.904000
99001.0375000.7493481.250800239.841000
100001.0163000.7491851.284500233.558000
101001.0045000.7492561.281500234.098000
102001.0166000.7490981.250200239.955000
103001.0212000.7478271.246300240.721000
104001.0214000.7482771.280900234.210000
105001.0090000.7489321.258700238.340000
106001.0384000.7489081.262000237.710000
107001.0063000.7486891.248100240.369000
108001.0006000.7485981.272400235.779000
109001.0204000.7486841.267700236.654000
110001.0271000.7473681.257500238.576000
111001.0199000.7477491.279100234.534000
112001.0439000.7472301.257400238.587000
113001.0478000.7470711.262600237.613000
114001.0636000.7469201.266600236.851000
115001.0166000.7471751.275900235.128000
116001.0400000.7468951.243800241.200000
117001.0286000.7466851.240900241.760000
118001.0146000.7468291.481100202.547000
119001.0175000.7470981.591900188.452000
120000.9999000.7458861.469700204.119000
121001.0304000.7457211.436700208.816000
122001.0396000.7459571.454700206.223000
123001.0232000.7456781.836300163.370000
124001.0146000.7454021.294600231.729000
125001.0335000.7448561.253300239.374000
126001.0322000.7453651.547900193.816000
127001.0187000.7451551.251700239.684000
128001.0163000.7440661.243200241.310000
129001.0369000.7436841.282300233.959000
130001.0036000.7442081.598900187.630000
131001.0176000.7446301.254700239.095000
132001.0209000.7456081.240800241.770000
133001.0039000.7450241.275700235.162000
134001.0447000.7442311.253400239.349000
135000.9952000.7442731.644500182.424000
136001.0196000.7437071.245900240.798000
137001.0122000.7441011.285800233.320000
138001.0020000.7435791.267500236.678000
139001.0242000.7431221.293900231.864000
140001.0141000.7441321.252500239.522000
141001.0279000.7435751.269800236.266000
142000.9938000.7432171.269300236.345000
143000.9856000.7432781.280700234.251000
144000.9896000.7432271.255900238.880000
145000.9776000.7434791.262200237.686000
146001.0004000.7429271.259000238.282000
147000.9934000.7417081.539800194.834000
148000.9938000.7413941.532100195.812000
149001.0001000.7404451.604900186.927000
150001.0208000.7402951.490500201.274000
151001.0008000.7409181.462000205.205000
152000.9757000.7406641.277300234.879000
153001.0135000.7399491.284700233.514000
154001.0078000.7398601.244100241.136000
155001.0246000.7400611.859200161.360000
156001.0152000.7401421.471400203.889000
157001.0068000.7403171.447000207.322000
158001.0078000.7402591.432000209.492000
159001.0045000.7404751.448100207.164000
160000.9719000.7403641.352500221.805000
161000.9799000.7404241.278300234.695000
162000.9903000.7402861.528500196.277000
163000.9927000.7407091.274100235.463000
164001.0006000.7400671.264900237.169000
165001.0040000.7405541.298000231.122000
166000.9874000.7397711.275400235.219000
167000.9657000.7394811.562800191.959000
168000.9993000.7398901.303900230.078000
169001.0142000.7395231.291700232.250000
170000.9816000.7403191.315600228.028000
171000.9904000.7396531.604000187.034000
172000.9906000.7398931.340100223.861000
173001.0015000.7399221.258800238.324000
174001.0160000.7391801.268600236.475000
175001.0136000.7396081.248700240.248000
176000.9840000.7397381.251400239.724000
177000.9761000.7394441.252400239.538000
178000.9823000.7387691.249200240.153000
179000.9893000.7387931.551400193.368000
180001.0123000.7387991.537700195.094000
181001.0381000.7391481.251200239.778000
182001.0100000.7393961.263800237.373000
183000.9939000.7402521.249700240.052000
184000.9784000.7399431.265800237.003000
185000.9986000.7396941.267100236.755000
186001.0129000.7387251.342900223.400000
187000.9919000.7386191.290600232.456000
188001.0180000.7384261.515000198.026000
189000.9826000.7388111.494000200.808000
190000.9542000.7385611.508100198.922000
191001.0181000.7390361.306000229.707000
192001.0078000.7387291.310700228.877000
193000.9885000.7388321.288000232.916000
194000.9893000.7386461.293900231.858000
195000.9899000.7381201.289200232.704000
196001.0045000.7380411.296600231.383000
197000.9735000.7382741.293300231.959000
198001.0215000.7382641.288700232.792000
199001.0201000.7381591.287900232.932000
200000.9881000.7388301.288400232.845000
201000.9889000.7389991.351100222.044000
202000.9467000.7391301.295800231.523000
203000.9852000.7383041.295200231.627000
204000.9762000.7389221.293200231.979000
205000.9726000.7383791.292200232.153000
206000.9773000.7382921.293500231.922000
207001.0141000.7385171.293500231.923000
208000.9869000.7387091.293300231.972000
209001.0045000.7378541.288800232.769000
210000.9962000.7381311.281100234.165000
211000.9948000.7377211.279400234.477000
212000.9904000.7380661.286100233.268000
213001.0011000.7381741.287700232.973000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " \n", + " [38/38 00:01]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'eval_loss': 0.7315998077392578,\n", + " 'eval_runtime': 1.383,\n", + " 'eval_samples_per_second': 216.921,\n", + " 'epoch': 3.0}" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluate()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the Trump administration's internal policy is nonsense.\n", + "the internal policy of Donald Trump is nonsense.\n", + "the internal policy of President Trump is nonsense.\n", + "the Trump administration's internal policy is crazy.\n", + "the president's internal policy is nonsense.\n", + "the internal policy of Mr. Trump is nonsense.\n", + "the Trump administration's internal policy is stupid.\n", + "the internal policy of Trump is nonsense.\n", + "the Trump administration's internal policy is bad.\n", + "the Trump internal policy is nonsense.\n" + ] + } + ], + "source": [ + "inputs = tokenizer('The internal policy of the fucking Trump is stupid.', return_tensors='pt')\n", + "inputs = {k: v.to(device) for k, v in inputs.items()}\n", + "for t in model.generate(**inputs, num_return_sequences=10, do_sample=False, num_beams=10):\n", + " print(tokenizer.decode(t, skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_pretrained(save_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/emnlp2021/style_transfer/mining_parallel_corpus/paranmt_mining.ipynb b/emnlp2021/style_transfer/mining_parallel_corpus/paranmt_mining.ipynb new file mode 100644 index 0000000..409827a --- /dev/null +++ b/emnlp2021/style_transfer/mining_parallel_corpus/paranmt_mining.ipynb @@ -0,0 +1,1153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before starting the mining process, please download the corpus from https://www.cs.cmu.edu/~jwieting/ " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "fn = '/home/dale/data/paraphrase_corpora/para-nmt-50m/para-nmt-50m.txt'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import csv\n", + "pd.options.display.max_colwidth = 150" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '4'" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(fn, sep='\\t', header=None, nrows=1_000_000, encoding='utf-8', 
quoting=csv.QUOTE_NONE).dropna()\n", + "df.columns = ['reference', 'translation', 'similarity']" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "def lenth_diff(row):\n", + " l1 = len(row.reference)\n", + " l2 = len(row.translation)\n", + " return np.abs(l1-l2) / (max(l1, l2) + 1)\n", + "\n", + "df['lenght_diff'] = df.apply(lenth_diff, axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7133114266228533\n", + "0.8754667509335019\n", + "0.8481736963473927\n", + "0.5371920743841487\n" + ] + } + ], + "source": [ + "print(np.mean((df.similarity > 0.6)))\n", + "print(np.mean(df.similarity <= 0.95))\n", + "print(np.mean(df.lenght_diff <= 0.4))\n", + "print(np.mean((df.similarity > 0.6) & (df.similarity <= 0.95) & (df.lenght_diff <= 0.4)))" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1000, 4)" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nonsim = df[(df.similarity > 0.6) & (df.similarity <= 0.95) & (df.lenght_diff <= 0.4)].head(1000)\n", + "nonsim.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.auto import tqdm, trange\n", + "from transformers import RobertaTokenizer, RobertaForSequenceClassification\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = 'SkolkovoInstitute/roberta_toxicity_classifier_v1'\n", + "\n", + "tokenizer = RobertaTokenizer.from_pretrained(model_name)\n", + "model = RobertaForSequenceClassification.from_pretrained(model_name)\n", + "\n", + "\n", + "def classify_preds(preds, batch_size=32, soft=True, threshold=0.5, soft=False):\n", + " single = False\n", + " if isinstance(preds, str):\n", + " preds = [preds]\n", + " single = True\n", + " results = []\n", + " \n", + " f = trange if verbose else range\n", + "\n", + " for i in f(0, len(preds), batch_size)):\n", + " batch = tokenizer(preds[i:i + batch_size], return_tensors='pt', padding=True)\n", + " with torch.inference_mode():\n", + " logits = model(**batch).logits\n", + " if soft:\n", + " result = torch.softmax(logits, -1)[:, 1].cpu().numpy()\n", + " else:\n", + " result = (logits[:, 1] > threshold).cpu().numpy()\n", + " results.extend([1 - item for item in result])\n", + " if single:\n", + " return np.mean(results)\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "febffe55f3794fb484861b9a9b56e66c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dale/p3/lib/python3.7/site-packages/pandas/core/frame.py:1490: FutureWarning: Using short name for 'orient' is deprecated. Only the options: ('dict', list, 'series', 'split', 'records', 'index') will be used in a future version. 
Use one of the above to silence this warning.\n", + " FutureWarning,\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dale/p3/lib/python3.7/site-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " \"\"\"Entry point for launching an IPython kernel.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d64463b92f054f37a19e07063796e5da", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=52.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dale/p3/lib/python3.7/site-packages/ipykernel_launcher.py:2: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " \n" + ] + } + ], + "source": [ + "nonsim['ref_tox'] = classify_preds(nonsim.reference.tolist(), verbose=True, batch_size=64)\n", + "nonsim['trn_tox'] = classify_preds(nonsim.translation.tolist(), verbose=True, batch_size=64)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 95, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPZElEQVR4nO3df6zd9V3H8ed7XAHhspbR2ZC27rKsmzZtVLiBLiTz3nUxpTOUREZY2ChLtdlkiDITqvtjRmOEPxgBQqaNLCum7sLqYhsG6izckC222g6kUJy7sMJaazugVMsPB/HtH+cD1nrbc3rvOefL/dznI7np98fne77v9z2X1/3ez/meQ2QmkqS6vKvpAiRJ3We4S1KFDHdJqpDhLkkVMtwlqUIDTRcAMG/evBwaGprSsa+88gpnn312dwt6h7Pn2cGeZ4fp9Lxr164XMvO9k+17R4T70NAQO3funNKx4+PjjIyMdLegdzh7nh3seXaYTs8R8dyJ9jktI0kVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFXpHvEN1OnbvP8J167/VyLn33vLxRs4rSe145S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVaijcI+I34mIpyLiyYj4ekScGREXRMSOiJiIiPsi4vQy9oyyPlH2D/W0A0nS/9M23CNiAfBbwHBmLgVOA64GbgVuz8wPAIeBteWQtcDhsv32Mk6S1EedTssMAD8dEQPAWcAB4KPA5rJ/I3BFWV5d1in7V0REdKVaSVJHIjPbD4q4Efhj4DXg74Abge3l6pyIWAQ8lJlLI+JJYGVm7iv7ngEuycwXjnvMdcA6gPnz5180NjY2pQYOvXSEg69N6dBpW7ZgTiPnPXr0KIODg42cuyn2PDvY86kZHR3dlZnDk+0baHdwRJxL62r8AuBl4BvAyilVcozM3ABsABgeHs6RkZEpPc5dm7Zw2+62bfTE3mtGGjnv+Pg4U/1+zVT2PDvYc/d0Mi3zMeCHmfnjzHwD+CZwKTC3TNMALAT2l+X9wCKAsn8O8GJXq5YknVQn4f48sDwizipz5yuAPcAjwJVlzBpgS1neWtYp+x/OTuZ+JEld0zbcM3MHrRdGvwfsLsdsAG4GboqICeA84J5yyD3AeWX7TcD6HtQtSTqJjiarM/NLwJeO2/wscPEkY18HPjH90iRJU+U7VCWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAp1FO4RMTciNkfEv0TE0xHx4Yh4T0R8OyJ+UP49t4yNiLgzIiYi4omIuLC3LUiSjtfplfsdwN9k5s8BvwA8DawHtmXmYmBbWQe4DFhcvtYBX+lqxZKkttqGe0TMAT4C3AOQmT/JzJeB1cDGMmwjcEVZXg3cmy3bgbkRcX6X65YknURk5skHRPwisAHYQ+uqfRdwI7A/M+eWMQEczsy5EfEAcEtmfqfs2wbcnJk7j3vcdbSu7Jk/f/5FY2NjU2rg0EtHOPjalA6dtmUL5jRy3qNHjzI4ONjIuZtiz7ODPZ+a0dHRXZk5PNm+gQ6OHwAuBG7IzB0RcQf/OwUDQGZmRJz8t8RxMnMDrV8aDA8P58jIyKkc/ra7Nm3htt2dtNF9e68ZaeS84+PjTPX7NVPZ8+xgz93TyZz7PmBfZu4o65tphf3Bt6Zbyr+Hyv79wKJjjl9YtkmS+qRtuGfmvwM/iogPlU0raE3RbAXWlG1rgC1leStwbblrZjlwJDMPdLdsSdLJdDqfcQOwKSJOB54FPkPrF8P9EbEWeA64qox9EFgFTACvlrGSpD7qKNwz83Fgskn7FZOMTeD66ZUlSZoO36EqSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVqONwj4jTIuKxiHigrF8QETsiYiIi7ouI08v2M8r6RNk/1KPaJUkncCpX7jcCTx+zfitwe2Z+ADgMrC3b1wKHy/bbyzhJUh91FO4RsRD4OPDnZT2AjwKby5CNwBVleXVZp+xfUcZLkvokMrP9oIjNwJ8A5wC/C1wHbC9X50TEIuChzFwaEU8CKzNzX9n3DHBJZr5w3GOuA9YBzJ8//6KxsbEpNXDopSMcfG1Kh07bsgVzGjnv0aNHGRwcbOTcTbHn2cGeT83o6OiuzByebN9Au4Mj4leBQ5m5KyJGplTBJDJzA7ABYHh4OEdGpvbQd23awm2727bRE3uvGWnkvOPj40z1+zVT2fPsYM/d00kqXgpcHhGrgDOBdwN3AHMjYiAz3wQWAvvL+P3AImBfRAwAc4AXu165JOmE2s65Z+bvZebCzBwCrgYezsxrgEeAK8uwNcCWsry1rFP2P5ydzP1IkrpmOve53wzcFBETwHnAPWX7PcB5ZftNwPrplShJOlWnNFmdmePAeFl+Frh4kjGvA5/oQm2SpCnyHaqSVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KF2oZ7RCyKiEciYk9EPBURN5bt74mIb0fED8q/55btERF3RsRERDwRERf2uglJ0v/VyZX7m8AXMnMJsBy4PiKWAOuBbZm5GNhW1gEuAxaXr3XAV7petSTppNqGe2YeyMzvleX/BJ4GFgCrgY1l2EbgirK8Grg3W7YDcyPi/G4XLkk6scjMzgdHDAGPAkuB5zNzbtkewOHMnBsRDwC3ZOZ3yr5twM2ZufO4x1pH68qe+fPnXzQ2NjalBg69dISDr03p0GlbtmBOI+c9evQog4ODjZy7KfY8O9jzqRkdHd2VmcOT7Rvo9EEiYhD4K+C3M/M/WnnekpkZEZ3/lmgdswHYADA8PJwjIyOncvjb7tq0hdt2d9xGV+29ZqSR846PjzPV79dMZc+zgz13T0d3y0TET9EK9k2Z+c2y+eBb0y3l30Nl+3
5g0TGHLyzbJEl90sndMgHcAzydmV8+ZtdWYE1ZXgNsOWb7teWumeXAkcw80MWaJUltdDKfcSnwaWB3RDxetv0+cAtwf0SsBZ4Drir7HgRWARPAq8BnulmwJKm9tuFeXhiNE+xeMcn4BK6fZl2SpGnwHaqSVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUaaLoASWra0PpvNXbur608uyeP65W7JFXIcJekChnuklQhw12SKmS4S1KFehLuEbEyIr4fERMRsb4X55AknVjXwz0iTgPuBi4DlgCfjIgl3T6PJOnEenGf+8XARGY+CxARY8BqYE8PztWopu6N7dV9sZ1oqucvLHuT6xq8F7kJTT3PTd7zPRuf516JzOzuA0ZcCazMzF8v658GLsnMzx83bh2wrqx+CPj+FE85D3hhisfOVPY8O9jz7DCdnt+Xme+dbEdj71DNzA3Ahuk+TkTszMzhLpQ0Y9jz7GDPs0Oveu7FC6r7gUXHrC8s2yRJfdKLcP8nYHFEXBARpwNXA1t7cB5J0gl0fVomM9+MiM8DfwucBnw1M5/q9nmOMe2pnRnInmcHe54detJz119QlSQ1z3eoSlKFDHdJqtCMCfd2H2kQEWdExH1l/46IGGqgzK7qoOebImJPRDwREdsi4n1N1NlNnX50RUT8WkRkRMz42+Y66TkirirP9VMR8Zf9rrHbOvjZ/tmIeCQiHis/36uaqLNbIuKrEXEoIp48wf6IiDvL9+OJiLhw2ifNzHf8F60XZp8B3g+cDvwzsOS4Mb8J/GlZvhq4r+m6+9DzKHBWWf7cbOi5jDsHeBTYDgw3XXcfnufFwGPAuWX9Z5quuw89bwA+V5aXAHubrnuaPX8EuBB48gT7VwEPAQEsB3ZM95wz5cr97Y80yMyfAG99pMGxVgMby/JmYEVERB9r7La2PWfmI5n5alndTus9BTNZJ88zwB8BtwKv97O4Humk598A7s7MwwCZeajPNXZbJz0n8O6yPAf4tz7W13WZ+Sjw0kmGrAbuzZbtwNyIOH8655wp4b4A+NEx6/vKtknHZOabwBHgvL5U1xud9HystbR+889kbXsuf64uysxaPoCkk+f5g8AHI+K7EbE9Ilb2rbre6KTnPwA+FRH7gAeBG/pTWmNO9b/3tvwfZFcgIj4FDAO/3HQtvRQR7wK+DFzXcCn9NkBramaE1l9nj0bEssx8ucmieuyTwNcy87aI+DDwFxGxNDP/u+nCZoqZcuXeyUcavD0mIgZo/Sn3Yl+q642OPsYhIj4GfBG4PDP/q0+19Uq7ns8BlgLjEbGX1tzk1hn+omonz/M+YGtmvpGZPwT+lVbYz1Sd9LwWuB8gM/8BOJPWB2zVqusf2zJTwr2TjzTYCqwpy1cCD2d5pWKGattzRPwS8Ge0gn2mz8NCm54z80hmzsvMocwcovU6w+WZubOZcruik5/tv6Z11U5EzKM1TfNsH2vstk56fh5YARARP08r3H/c1yr7aytwbblrZjlwJDMPTOsRm34V+RRebV5F64rlGeCLZdsf0vqPG1pP/jeACeAfgfc3XXMfev574CDwePna2nTNve75uLHjzPC7ZTp8noPWdNQeYDdwddM196HnJcB3ad1J8zjwK03XPM1+vw4cAN6g9ZfYWuCzwGePeY7vLt+P3d34ufbjBySpQjNlWkaSdAoMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklSh/wFzrK5YyWChsAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "nonsim.trn_tox.hist()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* We filter sentences with moderate similarity (60% to 95%) and similar length (no more then 40% difference) - 50%% of corpus\n", + "* Out of them, 2% have estimated toxicity is very different between the source and the translation (20 out of 1000)\n", + "* Out of these 18:\n", + " * 2 are false paraphrases\n", + " * 8 in fact have comparable toxicity\n", + " * 10 are valid detoxifications" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.02\n", + "20\n" + ] + } + ], + "source": [ + "fltr = np.abs(nonsim.trn_tox-nonsim.ref_tox) > 0.5\n", + "print(np.mean(fltr))\n", + "print(np.sum(fltr))" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "pd.options.display.max_colwidth = 500" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
referencetranslationsimilaritylenght_diffref_toxtrn_tox
0If Alkar is flooding her with psychic waste, that explains the high level of neurotransmitters.if Alkar floods her with her mental waste, it would explain the high levels of neurotransmitter.0.7851710.0103090.0141950.981983
1Now you're getting nasty.you're becoming disgusting.0.7496870.0714290.0654730.999039
2Well, we could spare your life, for one.well, we can spare your life.0.9190510.2682930.2133130.985068
3Ah! Monkey, you've got to snap out of it.monkey, you have to wake up.0.6643330.3095240.0533620.994215
4I've got orders to put her down.I have orders to kill her.0.7266390.1818180.0094020.999348
5I'm not gonna have a child... ...with the same genetic disorder as me who's gonna die. L...I'm not going to breed kids with a genetic disorder that makes them die.0.7031850.2065220.9509560.035846
6They're all laughing at us, so we'll kick your ass.they're laughing at us. We'll show you.0.6188660.2307690.9994920.000131
7Maine was very short on black people back then.there wasn't much black in Maine then.0.7204820.1875000.9636800.148710
8Briggs, what the hell's happening?Briggs, what the hell is going on?0.9203730.0000000.1590960.841071
9Another one simply had no clue what to do, so whenever he met my brother he'd beat the crap out of him, and then say:another simply didn't know what to do, so whenever he met my brother, he nearly beat the shit out of him.0.8775400.1016950.0553710.930472
10I suppose you want me to buy you flowers and chocolates and whisper sweet nothings.you'd probably want me to buy you some chocolates and flowers... and whispered some pretty rubbish.0.8006610.1600000.0000780.980341
11So now their spirits are cursed, walking back roads, waterways, and if they find an unfaithful man, they kill him, and that man is never seen again.their souls are cursed, they guard the paths, he says, and when they encounter an unfaithful man, he will be killed, and his body will never be found.0.7558830.0132450.8425090.143992
12Freezing him.I'll freeze him!0.7756460.1764710.0071860.573710
13Come on, Cal, leave that shit alone.come on, Cal, put it down.0.6604810.2702700.9996370.000279
14So he's the Top dog.he's the tallest son of a bitch.0.6110920.3636360.0009200.999639
15I swore when I went out with Xander Harris... ...I'd rather die than datea fixer-upper again.when I was dating Alex Harris, I swore I'd rather die than go out with a loser.0.7905650.1489360.0116130.996266
16I'm famous, and you're done.I'm famous, and you're dead.0.8172530.0000000.0009260.979738
17To quote Jake Oppenheimer: I, who am about to die, must seem to them something \"God-awful.\"...to quote Jake and Oppenheimer: \"I must die, I must feel like a terrible god.\"0.6995900.1789470.0023480.682655
18“Could you please be quiet, Miss Lavish?” said Moist.'could you keep your mouth shut, Miss Opulent? 'Said Moist.0.8089430.1000000.0001870.760356
19Murder for hire.murder to order.0.6976670.0000000.0745890.962326
\n", + "
" + ], + "text/plain": [ + " reference \\\n", + "0 If Alkar is flooding her with psychic waste, that explains the high level of neurotransmitters. \n", + "1 Now you're getting nasty. \n", + "2 Well, we could spare your life, for one. \n", + "3 Ah! Monkey, you've got to snap out of it. \n", + "4 I've got orders to put her down. \n", + "5 I'm not gonna have a child... ...with the same genetic disorder as me who's gonna die. L... \n", + "6 They're all laughing at us, so we'll kick your ass. \n", + "7 Maine was very short on black people back then. \n", + "8 Briggs, what the hell's happening? \n", + "9 Another one simply had no clue what to do, so whenever he met my brother he'd beat the crap out of him, and then say: \n", + "10 I suppose you want me to buy you flowers and chocolates and whisper sweet nothings. \n", + "11 So now their spirits are cursed, walking back roads, waterways, and if they find an unfaithful man, they kill him, and that man is never seen again. \n", + "12 Freezing him. \n", + "13 Come on, Cal, leave that shit alone. \n", + "14 So he's the Top dog. \n", + "15 I swore when I went out with Xander Harris... ...I'd rather die than datea fixer-upper again. \n", + "16 I'm famous, and you're done. \n", + "17 To quote Jake Oppenheimer: I, who am about to die, must seem to them something \"God-awful.\"... \n", + "18 “Could you please be quiet, Miss Lavish?” said Moist. \n", + "19 Murder for hire. \n", + "\n", + " translation \\\n", + "0 if Alkar floods her with her mental waste, it would explain the high levels of neurotransmitter. \n", + "1 you're becoming disgusting. \n", + "2 well, we can spare your life. \n", + "3 monkey, you have to wake up. \n", + "4 I have orders to kill her. \n", + "5 I'm not going to breed kids with a genetic disorder that makes them die. \n", + "6 they're laughing at us. We'll show you. \n", + "7 there wasn't much black in Maine then. \n", + "8 Briggs, what the hell is going on? \n", + "9 another simply didn't know what to do, so whenever he met my brother, he nearly beat the shit out of him. \n", + "10 you'd probably want me to buy you some chocolates and flowers... and whispered some pretty rubbish. \n", + "11 their souls are cursed, they guard the paths, he says, and when they encounter an unfaithful man, he will be killed, and his body will never be found. \n", + "12 I'll freeze him! \n", + "13 come on, Cal, put it down. \n", + "14 he's the tallest son of a bitch. \n", + "15 when I was dating Alex Harris, I swore I'd rather die than go out with a loser. \n", + "16 I'm famous, and you're dead. \n", + "17 to quote Jake and Oppenheimer: \"I must die, I must feel like a terrible god.\" \n", + "18 'could you keep your mouth shut, Miss Opulent? 'Said Moist. \n", + "19 murder to order. 
\n", + "\n", + " similarity lenght_diff ref_tox trn_tox \n", + "0 0.785171 0.010309 0.014195 0.981983 \n", + "1 0.749687 0.071429 0.065473 0.999039 \n", + "2 0.919051 0.268293 0.213313 0.985068 \n", + "3 0.664333 0.309524 0.053362 0.994215 \n", + "4 0.726639 0.181818 0.009402 0.999348 \n", + "5 0.703185 0.206522 0.950956 0.035846 \n", + "6 0.618866 0.230769 0.999492 0.000131 \n", + "7 0.720482 0.187500 0.963680 0.148710 \n", + "8 0.920373 0.000000 0.159096 0.841071 \n", + "9 0.877540 0.101695 0.055371 0.930472 \n", + "10 0.800661 0.160000 0.000078 0.980341 \n", + "11 0.755883 0.013245 0.842509 0.143992 \n", + "12 0.775646 0.176471 0.007186 0.573710 \n", + "13 0.660481 0.270270 0.999637 0.000279 \n", + "14 0.611092 0.363636 0.000920 0.999639 \n", + "15 0.790565 0.148936 0.011613 0.996266 \n", + "16 0.817253 0.000000 0.000926 0.979738 \n", + "17 0.699590 0.178947 0.002348 0.682655 \n", + "18 0.808943 0.100000 0.000187 0.760356 \n", + "19 0.697667 0.000000 0.074589 0.962326 " + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nonsim[fltr].reset_index(drop=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Large scale fine tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [], + "source": [ + "chunksize = 3_000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(fn, sep='\\t', header=None, nrows=1_000_000, encoding='utf-8', quoting=csv.QUOTE_NONE).dropna()\n", + "df.columns = ['reference', 'translation', 'similarity']" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.simplefilter(action='ignore', category=FutureWarning)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "* 50M / 3K = 16K iterations\n", + "* with 1 minute/iteration, the job will take 11 days, but with free GPUs this is faster" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3803079445c04b6186b32195ac10fa9f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "380 samples\n", + "689 samples\n", + "1025 samples\n", + "1353 samples\n", + "1663 samples\n", + "2012 samples\n", + "2353 samples\n", + "2696 samples\n", + "3033 samples\n", + "3372 samples\n", + "3709 samples\n", + "4040 samples\n", + "4385 samples\n", + "4702 samples\n", + "5046 samples\n", + "5405 samples\n", + "5742 samples\n", + "6046 samples\n", + "6374 samples\n", + "6698 samples\n", + "7042 samples\n", + "7386 samples\n", + "7718 samples\n", + "8015 samples\n", + "8374 samples\n", + "8679 samples\n", + "9031 samples\n", + "9371 samples\n", + "9735 samples\n", + "10066 samples\n", + "10385 samples\n", + "10728 samples\n", + "11046 samples\n", + "11395 samples\n", + "11748 samples\n", + "12081 samples\n", + "12414 samples\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "for i, chunk in enumerate(tqdm(pd.read_csv(fn, sep='\\t', header=None, encoding='utf-8', quoting=csv.QUOTE_NONE, chunksize=chunksize))):\n", + " 
chunk.dropna(inplace=True)\n",
    "    chunk.columns = ['reference', 'translation', 'similarity']\n",
    "    chunk['lenght_diff'] = chunk.apply(lenth_diff, axis=1)\n",
    "    nonsim = chunk[(chunk.similarity > 0.6) & (chunk.similarity <= 0.95) & (chunk.lenght_diff <= 0.4)].copy()\n",
    "    \n",
    "    # toxicity scores for both sides of each candidate pair\n",
    "    nonsim['ref_tox'] = classify_preds(nonsim.reference.tolist(), verbose=False, batch_size=64)\n",
    "    nonsim['trn_tox'] = classify_preds(nonsim.translation.tolist(), verbose=False, batch_size=64)\n",
    "    \n",
    "    fltr = np.abs(nonsim.trn_tox - nonsim.ref_tox) > 0.5\n",
    "    mined = nonsim[fltr]\n",
    "    results.append(mined)\n",
    "    # print(nonsim.shape[0], mined.shape[0])\n",
    "    # checkpoint the accumulated pairs to disk every 10 chunks\n",
    "    if i > 0 and i % 10 == 0:\n",
    "        res_df = pd.concat(results, ignore_index=True)\n",
    "        print(res_df.shape[0], 'samples')\n",
    "        res_df.to_csv('filtered.tsv', sep='\t', encoding='utf-8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(577988, 6)"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df = pd.concat(results, ignore_index=True)\n",
    "res_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
       "    <tr style=\"text-align: right;\"><th></th><th>similarity</th><th>lenght_diff</th><th>ref_tox</th><th>trn_tox</th></tr>\n",
       "    <tr><th>count</th><td>577988.000000</td><td>577988.000000</td><td>577988.000000</td><td>577988.000000</td></tr>\n",
       "    <tr><th>mean</th><td>0.758466</td><td>0.157654</td><td>0.541382</td><td>0.434479</td></tr>\n",
       "    <tr><th>std</th><td>0.092696</td><td>0.108056</td><td>0.457569</td><td>0.458904</td></tr>\n",
       "    <tr><th>min</th><td>0.600001</td><td>0.000000</td><td>0.000033</td><td>0.000033</td></tr>\n",
       "    <tr><th>25%</th><td>0.681099</td><td>0.066667</td><td>0.012172</td><td>0.000707</td></tr>\n",
       "    <tr><th>50%</th><td>0.754436</td><td>0.141818</td><td>0.806834</td><td>0.085110</td></tr>\n",
       "    <tr><th>75%</th><td>0.831241</td><td>0.238095</td><td>0.990469</td><td>0.973742</td></tr>\n",
       "    <tr><th>max</th><td>0.950000</td><td>0.400000</td><td>0.999724</td><td>0.999730</td></tr>\n",
\n", + "
" + ], + "text/plain": [ + " similarity lenght_diff ref_tox trn_tox\n", + "count 577988.000000 577988.000000 577988.000000 577988.000000\n", + "mean 0.758466 0.157654 0.541382 0.434479\n", + "std 0.092696 0.108056 0.457569 0.458904\n", + "min 0.600001 0.000000 0.000033 0.000033\n", + "25% 0.681099 0.066667 0.012172 0.000707\n", + "50% 0.754436 0.141818 0.806834 0.085110\n", + "75% 0.831241 0.238095 0.990469 0.973742\n", + "max 0.950000 0.400000 0.999724 0.999730" + ] + }, + "execution_count": 130, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res_df.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
       "    <tr style=\"text-align: right;\"><th></th><th>reference</th><th>translation</th><th>similarity</th><th>lenght_diff</th><th>ref_tox</th><th>trn_tox</th></tr>\n",
       "    <tr><th>88294</th><td>I just beat up three guys by myself.</td><td>I just cut three big men.</td><td>0.685685</td><td>0.297297</td><td>0.750938</td><td>0.005162</td></tr>\n",
       "    <tr><th>402968</th><td>Whoo! Welcome back, mad dog!</td><td>welcome back, Fury!</td><td>0.735353</td><td>0.310345</td><td>0.998412</td><td>0.000081</td></tr>\n",
       "    <tr><th>396420</th><td>The roof is scraping my nut!</td><td>the roof is cutting off my balls!</td><td>0.746109</td><td>0.147059</td><td>0.012407</td><td>0.993682</td></tr>\n",
       "    <tr><th>344928</th><td>Blasted fairy music.</td><td>to hell with the music.</td><td>0.667527</td><td>0.125000</td><td>0.001486</td><td>0.932651</td></tr>\n",
       "    <tr><th>57239</th><td>A kidnapper and a child murderer, yes.</td><td>the kidnapper and the child killer, yes.</td><td>0.888592</td><td>0.048780</td><td>0.967770</td><td>0.121575</td></tr>\n",
       "    <tr><th>444506</th><td>For awhile, I went to sex clubs.</td><td>I went to sex clubs for a while.</td><td>0.923540</td><td>0.000000</td><td>0.147675</td><td>0.966582</td></tr>\n",
       "    <tr><th>22691</th><td>When I released her she bites you.</td><td>if I leave her, she'll bite you.</td><td>0.675012</td><td>0.057143</td><td>0.006479</td><td>0.979373</td></tr>\n",
       "    <tr><th>242403</th><td>If you don't at least try, Alex, you're already a loser.</td><td>if you don't try, Alex, you lost in advance.</td><td>0.727835</td><td>0.210526</td><td>0.998164</td><td>0.000176</td></tr>\n",
       "    <tr><th>179191</th><td>'I know you're only the humorous thugs.</td><td>\"I'm sure you're only a couple of comic thieves.</td><td>0.621704</td><td>0.183673</td><td>0.924046</td><td>0.265508</td></tr>\n",
       "    <tr><th>541902</th><td>You know as well as I do what a monster he was.</td><td>you know as much as I did for the monster.</td><td>0.700937</td><td>0.104167</td><td>0.995559</td><td>0.002968</td></tr>\n",
\n", + "
" + ], + "text/plain": [ + " reference \\\n", + "88294 I just beat up three guys by myself. \n", + "402968 Whoo! Welcome back, mad dog! \n", + "396420 The roof is scraping my nut! \n", + "344928 Blasted fairy music. \n", + "57239 A kidnapper and a child murderer, yes. \n", + "444506 For awhile, I went to sex clubs. \n", + "22691 When I released her she bites you. \n", + "242403 If you don't at least try, Alex, you're already a loser. \n", + "179191 'I know you're only the humorous thugs. \n", + "541902 You know as well as I do what a monster he was. \n", + "\n", + " translation similarity \\\n", + "88294 I just cut three big men. 0.685685 \n", + "402968 welcome back, Fury! 0.735353 \n", + "396420 the roof is cutting off my balls! 0.746109 \n", + "344928 to hell with the music. 0.667527 \n", + "57239 the kidnapper and the child killer, yes. 0.888592 \n", + "444506 I went to sex clubs for a while. 0.923540 \n", + "22691 if I leave her, she'll bite you. 0.675012 \n", + "242403 if you don't try, Alex, you lost in advance. 0.727835 \n", + "179191 \"I'm sure you're only a couple of comic thieves. 0.621704 \n", + "541902 you know as much as I did for the monster. 0.700937 \n", + "\n", + " lenght_diff ref_tox trn_tox \n", + "88294 0.297297 0.750938 0.005162 \n", + "402968 0.310345 0.998412 0.000081 \n", + "396420 0.147059 0.012407 0.993682 \n", + "344928 0.125000 0.001486 0.932651 \n", + "57239 0.048780 0.967770 0.121575 \n", + "444506 0.000000 0.147675 0.966582 \n", + "22691 0.057143 0.006479 0.979373 \n", + "242403 0.210526 0.998164 0.000176 \n", + "179191 0.183673 0.924046 0.265508 \n", + "541902 0.104167 0.995559 0.002968 " + ] + }, + "execution_count": 131, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res_df.sample(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 132, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYkAAAD4CAYAAAAZ1BptAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVWUlEQVR4nO3df7DldX3f8ecrixiHREExdxiWunTcNF2lQd0BOuk0tzKFC53JYooOzFQWQ9xMhTbpbDuu6R9Ylan+YZihVdK17LA4RqQklm1YSxnkjpNOQDAgsFDDDWrZDUpkAbM6ate++8f5bHO83s+9h/vj3L3c52PmzP2e9/fz+X4+Hy7c1/3+uIdUFZIkzeVnVnsCkqTjlyEhSeoyJCRJXYaEJKnLkJAkdZ2w2hNYbqeeempt2rRpUX2/973vcdJJJy3vhI5zrnl9cM3rw1LW/JWvfOU7VfX62fWXXUhs2rSJBx98cFF9p6enmZycXN4JHedc8/rgmteHpaw5yTfnqnu5SZLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1PWy+4trSVpNm3bduWpj3zy1/B9D4pmEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktRlSEiSuhYMiSQ/m+TLSb6a5ECSf9fqZya5P8lMks8lObHVX9nez7T9m4aO9YFW/1qSC4fqU602k2TXUH3OMSRJ4zHKmcQPgbdX1S8DZwNTSc4DPgZcX1VvBJ4HrmrtrwKeb/XrWzuSbAEuA94ETAGfTLIhyQbgE8BFwBbg8taWecaQJI3BgiFRA0fa21e0VwFvB25v9b3AJW17W3tP239+krT6rVX1w6r6OjADnNNeM1X1VFX9CLgV2Nb69MaQJI3BSP/Tofbb/leANzL4rf8vgBeq6mhrchA4vW2fDjwNUFVHk7wIvK7V7xs67HCfp2fVz219emPMnt8OYAfAxMQE09PToyzrpxw5cmTRfdcq17w+uObx2XnW0YUbrZCVWPNIIVFVPwbOTnIy8Hngl5Z1FktUVbuB3QBbt26tycnJRR1nenqaxfZdq1zz+uCax+fKVf4/0y33ml/S001V9QJwL/D3gZOTHAuZjcChtn0IOAOg7X8N8NxwfVafXv25ecaQJI3BKE83vb6dQZDkVcA/Bp5gEBaXtmbbgTva9r72nrb/i1VVrX5Ze/rpTGAz8GXgAWBze5LpRAY3t/e1Pr0xJEljMMrlptOAve2+xM8At1XVHyd5HLg1yUeAh4CbWvubgE8nmQEOM/ihT1UdSHIb8DhwFLi6XcYiyTXAXcAGYE9VHWjHen9nDEnSGCwYElX1CPCWOepPMXgyaXb9B8A7O8e6Drhujvp+YP+oY0iSxsO/uJYkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSepaMCSSnJHk3iSPJzmQ5Ldb/YNJDiV5uL0uHurzgSQzSb6W5MKh+lSrzSTZNVQ/M8n9rf65JCe2+ivb+5m2f9Oyrl6SNK9RziSOAjuragtwHnB1ki1t3/VVdXZ77Qdo+y4D3gRMAZ9MsiHJBuATwEXAFuDyoeN8rB3rjcDzwFWtfhXwfKtf39pJksZkwZCoqmeq6s/a9l8DTwCnz9NlG3BrVf2wqr4OzADntNdMVT1VVT8CbgW2JQnwduD21n8vcMnQsfa27duB81t7SdIYvKR7Eu1yz1uA+1vpmiSPJNmT5JRWOx14eqjbwVbr1V8HvFBVR2fVf+JYbf+Lrb0kaQxOGLVhkp8D/hD4nar6bpIbgQ8D1b5+HPiNFZnlwnPbAewAmJiYYHp6elHHOXLkyKL7rlWueX1wzeOz86yjCzdaISux5pFCIskrGATEZ6rqjwCq6ttD+z8F/HF7ewg4Y6j7xlajU38OODnJCe1sYbj9sWMdTHIC8JrW/idU1W5gN8DWrVtrcnJylGX9lOnpaRbbd61yzeuDax6fK3fdOfYxj7l56qRlX/MoTzcFuAl4oqp+b6h+2lCzdwCPte19wGXtyaQzgc3Al4EHgM3tSaYTGdzc3ldVBdwLXNr6bwfuGDrW9rZ9KfDF1l6SNAajnEn8CvBu4NEkD7fa7zJ4OulsBpebvgH8FkBVHUhyG/A4gyejrq6qHwMkuQa4C9gA7KmqA+147wduTfIR4CEGoUT7+ukkM8BhBsEiSRqTBUOiqv4EmOuJov3z9LkOuG6O+v65+lXVUwyefppd/wHwzoXmKElaGf7FtSSpy5CQJHUZEpKkLkNCktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqWjAkkpyR5N4kjyc5kOS3W/21Se5O8mT7ekqrJ8kNSWaSPJLkrUPH2t7aP5lk+1D9bUkebX1uSJL5xpAkjccoZxJHgZ1VtQU4D7g6yRZgF3BPVW0G7mnvAS4CNrfXDuBGGPzAB64FzgXOAa4d+qF/I/DeoX5Trd4bQ5I0BguGRFU9U1V/1rb/GngCOB3YBuxtzfYCl7TtbcAtNXAfcHKS04ALgbur6nBVPQ/cDUy1fa+uqvuqqoBbZh1rrjEkSWNwwktpnGQT8BbgfmCiqp5pu74FTLTt04Gnh7odbLX56gfnqDPPGLPntYPBWQsTExNMT0+/lGX9f0eOHFl037XKNa8Prnl8dp51dOxjHrMSax45JJL8HPCHwO9U1XfbbQMAqqqS1LLObJb5xqiq3cBugK1bt9bk5OSixpienmaxfdcq17w+uObxuXLXnWMf85ibp05a9jWP9HRTklcwCIjPVNUftfK326Ui2tdnW/0QcMZQ942tNl994xz1+caQJI3BKE83BbgJeKKqfm9o1z7g2BNK24E7hupXtKeczgNebJeM7gIuSHJKu2F9AXBX2/fdJOe1sa6Yday5xpAkjcEol5t+BXg38GiSh1vtd4GPArcluQr4JvCutm8/cDEwA3wfeA9AVR1O8mHggdbuQ1V1uG2/D7gZeBXwhfZinjEkSWOwYEhU1Z8A6ew+f472BVzdOdYeYM8c9QeBN89Rf26uMSRJ4+FfXEuSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkrgVDIsmeJM8meWyo9sEkh5I83F4XD+37QJKZJF9LcuFQfarVZpLsGqqfmeT+Vv9ckhNb/ZXt/Uzbv2nZVi1JGskoZxI3A1Nz1K+vqrPbaz9Aki3AZcCbWp9PJtmQZAPwCeAiYAtweWsL8LF2rDcCzwNXtfpVwPOtfn1rJ0kaow
VDoqq+BBwe8XjbgFur6odV9XVgBjinvWaq6qmq+hFwK7AtSYC3A7e3/nuBS4aOtbdt3w6c39pLksbkhCX0vSbJFcCDwM6qeh44HbhvqM3BVgN4elb9XOB1wAtVdXSO9qcf61NVR5O82Np/Z/ZEkuwAdgBMTEwwPT29qAUdOXJk0X3XKte8Prjm8dl51tGFG62QlVjzYkPiRuDDQLWvHwd+Y7km9VJV1W5gN8DWrVtrcnJyUceZnp5msX3XKte8Prjm8bly151jH/OYm6dOWvY1L+rppqr6dlX9uKr+L/ApBpeTAA4BZww13dhqvfpzwMlJTphV/4ljtf2vae0lSWOyqJBIctrQ23cAx5582gdc1p5MOhPYDHwZeADY3J5kOpHBze19VVXAvcClrf924I6hY21v25cCX2ztJUljsuDlpiSfBSaBU5McBK4FJpOczeBy0zeA3wKoqgNJbgMeB44CV1fVj9txrgHuAjYAe6rqQBvi/cCtST4CPATc1Oo3AZ9OMsPgxvllS12sJOmlWTAkquryOco3zVE71v464Lo56vuB/XPUn+JvLlcN138AvHOh+UmSVo5/cS1J6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktS1YEgk2ZPk2SSPDdVem+TuJE+2r6e0epLckGQmySNJ3jrUZ3tr/2SS7UP1tyV5tPW5IUnmG0OSND6jnEncDEzNqu0C7qmqzcA97T3ARcDm9toB3AiDH/jAtcC5wDnAtUM/9G8E3jvUb2qBMSRJY7JgSFTVl4DDs8rbgL1tey9wyVD9lhq4Dzg5yWnAhcDdVXW4qp4H7gam2r5XV9V9VVXALbOONdcYkqQxWew9iYmqeqZtfwuYaNunA08PtTvYavPVD85Rn28MSdKYnLDUA1RVJanlmMxix0iyg8HlLSYmJpienl7UOEeOHFl037XKNa8Prnl8dp51dOxjHrMSa15sSHw7yWlV9Uy7ZPRsqx8Czhhqt7HVDgGTs+rTrb5xjvbzjfFTqmo3sBtg69atNTk52Ws6r+npaRbbd61yzeuDax6fK3fdOfYxj7l56qRlX/NiLzftA449obQduGOofkV7yuk84MV2yegu4IIkp7Qb1hcAd7V9301yXnuq6YpZx5prDEnSmCx4JpHkswzOAk5NcpDBU0ofBW5LchXwTeBdrfl+4GJgBvg+8B6Aqjqc5MPAA63dh6rq2M3w9zF4gupVwBfai3nGkCSNyYIhUVWXd3adP0fbAq7uHGcPsGeO+oPAm+eoPzfXGJKk8fEvriVJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUteSP7tJko5Hjx56cVU/IuPlwjMJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1LWkkEjyjSSPJnk4yYOt9tokdyd5sn09pdWT5IYkM0keSfLWoeNsb+2fTLJ9qP62dvyZ1jdLma8k6aVZjjOJf1RVZ1fV1vZ+F3BPVW0G7mnvAS4CNrfXDuBGGIQKcC1wLnAOcO2xYGlt3jvUb2oZ5itJGtFKXG7aBuxt23uBS4bqt9TAfcDJSU4DLgTurqrDVfU8cDcw1fa9uqruq6oCbhk6liRpDJb6vy8t4H8kKeA/VdVuYKKqnmn7vwVMtO3TgaeH+h5stfnqB+eo/5QkOxicnTAxMcH09PSiFnPkyJFF912rXPP6sB7XPPEq2HnW0dWexlitxPd5qSHxD6rqUJJfAO5O8r+Gd1ZVtQBZUS2cdgNs3bq1JicnF3Wc6elpFtt3rXLN68N6XPN/+MwdfPzRpf6IW1tunjpp2b/PS7rcVFWH2tdngc8zuKfw7XapiPb12db8EHDGUPeNrTZffeMcdUnSmCw6JJKclOTnj20DFwCPAfuAY08obQfuaNv7gCvaU07nAS+2y1J3ARckOaXdsL4AuKvt+26S89pTTVcMHUuSNAZLORebAD7fnko9AfiDqvrvSR4AbktyFfBN4F2t/X7gYmAG+D7wHoCqOpzkw8ADrd2Hqupw234fcDPwKuAL7SVJGpNFh0RVPQX88hz154Dz56gXcHXnWHuAPXPUHwTevNg5SpKWZn3d1ZE0Vpt23blqY+88a9WGflnxYzkkSV2GhCSpy5CQJHUZEpKkLkNCktRlSEiSugwJSVKXfychrQOPHnqRK1fxbxa0dnkmIUnqMiQkSV2GhCSpy3sS0pj4OUZaizyTkCR1eSahdccnfaTRGRJaFV56kdYGQ2LIav6G+Y2P/pNVGdffqiXNx5A4TqzWb9b+Vi1pPt64liR1GRKSpC5DQpLUZUhIkroMCUlS13EfEkmmknwtyUySXas9H0laT47rkEiyAfgEcBGwBbg8yZbVnZUkrR/HdUgA5wAzVfVUVf0IuBXYtspzkqR1I1W12nPoSnIpMFVVv9nevxs4t6qumdVuB7Cjvf07wNcWOeSpwHcW2Xetcs3rg2teH5ay5jdU1etnF18Wf3FdVbuB3Us9TpIHq2rrMkxpzXDN64NrXh9WYs3H++WmQ8AZQ+83tpokaQyO95B4ANic5MwkJwKXAftWeU6StG4c15ebqupokmuAu4ANwJ6qOrCCQy75ktUa5JrXB9e8Piz7mo/rG9eSpNV1vF9ukiStIkNCktS1LkNioY/6SHJlkr9K8nB7/eZqzHM5jfLxJkneleTxJAeS/MG457jcRvg+Xz/0Pf7zJC+swjSX1Qhr/ltJ7k3yUJJHkly8GvNcTiOs+Q1J7mnrnU6ycTXmuVyS7EnybJLHOvuT5Ib2z+ORJG9d0oBVta5eDG6A/wXwt4ETga8CW2a1uRL4j6s91zGveTPwEHBKe/8Lqz3vlV7zrPb/gsGDEas+9xX+Pu8G/nnb3gJ8Y7XnPYY1/xdge9t+O/Dp1Z73Etf8D4G3Ao919l8MfAEIcB5w/1LGW49nEuvxoz5GWfN7gU9U1fMAVfXsmOe43F7q9/ly4LNjmdnKGWXNBby6bb8G+Msxzm8ljLLmLcAX2/a9c+xfU6rqS8DheZpsA26pgfuAk5Octtjx1mNInA48PfT+YKvN9k/bqdrtSc6YY/9aMsqafxH4xST/M8l9SabGNruVMer3mSRvAM7kb36QrFWjrPmDwD9LchDYz+AMai0bZc1fBX69bb8D+PkkrxvD3FbLyP/uj2I9hsQo/huwqar+HnA3sHeV5zMOJzC45DTJ4LfqTyU5eTUnNEaXAbdX1Y9XeyJjcDlwc1VtZHBZ4tNJXu4/B/418KtJHgJ+lcGnNqyH7/WyeLn/yzGXBT/qo6qeq6oftrf/GXjbmOa2Ukb5eJODwL6q+j9V9XXgzxmExlr1Uj7S5TLW/qUmGG3NVwG3AVTVnwI/y+BD4daqUf57/suq+vWqegvwb1vthbHNcPyW9eOM1mNILPhRH7Ou3/0a8
MQY57cSRvl4k//K4CyCJKcyuPz01BjnuNxG+kiXJL8EnAL86ZjntxJGWfP/Bs4HSPJ3GYTEX411lstrlP+eTx06W/oAsGfMcxy3fcAV7Smn84AXq+qZxR7suP5YjpVQnY/6SPIh4MGq2gf8yyS/BhxlcIPoylWb8DIYcc13ARckeZzBqfi/qarnVm/WSzPimmHwQ+XWao+FrGUjrnkng0uJ/4rBTewr1/LaR1zzJPDvkxTwJeDqVZvwMkjyWQZrOrXdW7oWeAVAVf0+g3tNFwMzwPeB9yxpvDX874ckaYWtx8tNkqQRGRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXf8Ppx/KUOUF4zMAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "res_df['tox_diff'] = np.abs(res_df.ref_tox - res_df.trn_tox)\n", + "res_df.tox_diff.hist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file