diff --git a/b06504102_hw7/hw7.ipynb b/b06504102_hw7/hw7.ipynb new file mode 100644 index 0000000..587e9de --- /dev/null +++ b/b06504102_hw7/hw7.ipynb @@ -0,0 +1,1062 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "「hw7_bert」的副本", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "2ed6ef4de01843598855909b4209480e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_f2fd2cfef1714f4ca3c4a19d4f4c8029", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_1092a4b703aa429bb04b902424e88667", + "IPY_MODEL_e31d3c7a5edc4ae49f3f2ee1aad98858" + ] + } + }, + "f2fd2cfef1714f4ca3c4a19d4f4c8029": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "1092a4b703aa429bb04b902424e88667": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_cabe2e3a668f49959005b8a8f1600e4c", + "_dom_classes": [], + "description": " 0%", + "_model_name": "FloatProgressModel", + "bar_style": "danger", + "max": 1684, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_847745386d86408a8ca243418a08c2a2" + } + }, + "e31d3c7a5edc4ae49f3f2ee1aad98858": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_a590054b6ff542189409d24050743b51", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 1/1684 [00:00<08:59, 3.12it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_6f376e003f274abface1e0f063b06c63" + } + }, + "cabe2e3a668f49959005b8a8f1600e4c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "847745386d86408a8ca243418a08c2a2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "a590054b6ff542189409d24050743b51": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "6f376e003f274abface1e0f063b06c63": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "xvSGDbExff_I" + }, + "source": [ + "# **Homework 7 - Bert (Question Answering)**\n", + "\n", + "If you have any questions, feel free to email us at ntu-ml-2021spring-ta@googlegroups.com\n", + "\n", + "\n", + "\n", + "Slide: [Link](https://docs.google.com/presentation/d/1aQoWogAQo_xVJvMQMrGaYiWzuyfO0QyLLAhiMwFyS2w) Kaggle: [Link](https://www.kaggle.com/c/ml2021-spring-hw7) Data: [Link](https://drive.google.com/uc?id=1znKmX08v9Fygp-dgwo7BKiLIf2qL1FH1)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WGOr_eS3wJJf" + }, + "source": [ + "## Task description\n", + "- Chinese Extractive Question Answering\n", + " - Input: Paragraph + Question\n", + " - Output: Answer\n", + "\n", + "- Objective: Learn how to fine tune a pretrained model on downstream task using transformers\n", + "\n", + "- Todo\n", + " - Fine tune a pretrained chinese BERT model\n", + " - Change hyperparameters (e.g. doc_stride)\n", + " - Apply linear learning rate decay\n", + " - Try other pretrained models\n", + " - Improve preprocessing\n", + " - Improve postprocessing\n", + "- Training tips\n", + " - Automatic mixed precision\n", + " - Gradient accumulation\n", + " - Ensemble\n", + "\n", + "- Estimated training time (tesla t4 with automatic mixed precision enabled)\n", + " - Simple: 8mins\n", + " - Medium: 8mins\n", + " - Strong: 25mins\n", + " - Boss: 2hrs\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TJ1fSAJE2oaC" + }, + "source": [ + "## Download Dataset" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZJA_EEc3ZCH3", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "4f314101-0043-444b-ae03-ae1147b774b9" + }, + "source": [ + "!nvidia-smi" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Mon May 3 15:05:43 2021 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 465.19.01 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 45C P8 10W / 70W | 0MiB / 15109MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "YPrc4Eie9Yo5", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f8f4bfb9-0f18-49aa-b19b-9505ac4188f4" + }, + "source": [ + "# Download link 1\n", + "!gdown --id '1znKmX08v9Fygp-dgwo7BKiLIf2qL1FH1' --output hw7_data.zip\n", + "\n", + "# Download Link 2 (if the above link fails) \n", + "# !gdown --id '1pOu3FdPdvzielUZyggeD7KDnVy9iW1uC' --output hw7_data.zip\n", + "\n", + "!unzip -o hw7_data.zip\n", + "\n", + "# For this HW, K80 < P4 < T4 < P100 <= T4(fp16) < V100\n", + "!nvidia-smi" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?id=1znKmX08v9Fygp-dgwo7BKiLIf2qL1FH1\n", + "To: /content/hw7_data.zip\n", + "\r0.00B [00:00, ?B/s]\r7.71MB [00:00, 67.3MB/s]\n", + "Archive: hw7_data.zip\n", + " inflating: hw7_dev.json \n", + " inflating: hw7_test.json \n", + " inflating: hw7_train.json \n", + "Mon May 3 15:05:50 2021 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 465.19.01 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 44C P8 10W / 70W | 0MiB / 15109MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TevOvhC03m0h" + }, + "source": [ + "## Install transformers\n", + "\n", + "Documentation for the toolkit: https://huggingface.co/transformers/" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tbxWFX_jpDom", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "2f6484c5-153c-4b32-8852-ecde754a086d" + }, + "source": [ + "# You are allowed to change version of transformers or use other toolkits\n", + "!pip install transformers==4.5.0" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting transformers==4.5.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/81/91/61d69d58a1af1bd81d9ca9d62c90a6de3ab80d77f27c5df65d9a2c1f5626/transformers-4.5.0-py3-none-any.whl (2.1MB)\n", + "\r\u001b[K |▏ | 10kB 22.4MB/s eta 0:00:01\r\u001b[K |▎ | 20kB 29.2MB/s eta 0:00:01\r\u001b[K |▌ | 30kB 22.1MB/s eta 0:00:01\r\u001b[K |▋ | 40kB 25.4MB/s eta 0:00:01\r\u001b[K |▊ | 51kB 19.4MB/s eta 0:00:01\r\u001b[K |█ | 61kB 17.8MB/s eta 0:00:01\r\u001b[K |█ | 71kB 16.8MB/s eta 0:00:01\r\u001b[K |█▏ | 81kB 18.0MB/s eta 0:00:01\r\u001b[K |█▍ | 92kB 17.3MB/s eta 0:00:01\r\u001b[K |█▌ | 102kB 17.6MB/s eta 0:00:01\r\u001b[K |█▊ | 112kB 17.6MB/s eta 0:00:01\r\u001b[K |█▉ | 122kB 17.6MB/s eta 0:00:01\r\u001b[K |██ | 133kB 17.6MB/s eta 0:00:01\r\u001b[K |██▏ | 143kB 17.6MB/s eta 0:00:01\r\u001b[K |██▎ | 153kB 17.6MB/s eta 0:00:01\r\u001b[K |██▍ | 163kB 17.6MB/s eta 0:00:01\r\u001b[K |██▋ | 174kB 17.6MB/s eta 0:00:01\r\u001b[K |██▊ | 184kB 17.6MB/s eta 0:00:01\r\u001b[K |███ | 194kB 17.6MB/s eta 0:00:01\r\u001b[K |███ | 204kB 17.6MB/s eta 0:00:01\r\u001b[K |███▏ | 215kB 17.6MB/s eta 0:00:01\r\u001b[K |███▍ | 225kB 17.6MB/s eta 0:00:01\r\u001b[K |███▌ | 235kB 17.6MB/s eta 0:00:01\r\u001b[K |███▋ | 245kB 17.6MB/s eta 0:00:01\r\u001b[K |███▉ | 256kB 17.6MB/s eta 0:00:01\r\u001b[K |████ | 266kB 17.6MB/s eta 0:00:01\r\u001b[K |████▏ | 276kB 17.6MB/s eta 0:00:01\r\u001b[K |████▎ | 286kB 17.6MB/s eta 0:00:01\r\u001b[K |████▍ | 296kB 17.6MB/s eta 0:00:01\r\u001b[K |████▋ | 307kB 17.6MB/s eta 0:00:01\r\u001b[K |████▊ | 317kB 17.6MB/s eta 0:00:01\r\u001b[K |████▉ | 327kB 17.6MB/s eta 0:00:01\r\u001b[K |█████ | 337kB 17.6MB/s eta 0:00:01\r\u001b[K |█████▏ | 348kB 17.6MB/s eta 0:00:01\r\u001b[K |█████▍ | 358kB 17.6MB/s eta 0:00:01\r\u001b[K |█████▌ | 368kB 17.6MB/s eta 0:00:01\r\u001b[K |█████▋ | 378kB 17.6MB/s eta 0:00:01\r\u001b[K |█████▉ | 389kB 17.6MB/s eta 0:00:01\r\u001b[K |██████ | 399kB 17.6MB/s eta 0:00:01\r\u001b[K |██████ | 409kB 17.6MB/s eta 0:00:01\r\u001b[K |██████▎ | 419kB 17.6MB/s eta 0:00:01\r\u001b[K |██████▍ | 430kB 17.6MB/s eta 0:00:01\r\u001b[K |██████▋ | 440kB 17.6MB/s eta 0:00:01\r\u001b[K |██████▊ | 450kB 17.6MB/s eta 0:00:01\r\u001b[K |██████▉ | 460kB 17.6MB/s eta 0:00:01\r\u001b[K |███████ | 471kB 17.6MB/s eta 0:00:01\r\u001b[K |███████▏ | 481kB 17.6MB/s eta 0:00:01\r\u001b[K |███████▎ | 491kB 17.6MB/s eta 0:00:01\r\u001b[K |███████▌ | 501kB 17.6MB/s eta 0:00:01\r\u001b[K |███████▋ | 512kB 17.6MB/s eta 0:00:01\r\u001b[K |███████▉ | 522kB 17.6MB/s eta 0:00:01\r\u001b[K |████████ | 532kB 17.6MB/s eta 0:00:01\r\u001b[K |████████ | 542kB 17.6MB/s eta 0:00:01\r\u001b[K |████████▎ | 552kB 17.6MB/s eta 0:00:01\r\u001b[K |████████▍ | 563kB 17.6MB/s eta 0:00:01\r\u001b[K |████████▌ | 573kB 17.6MB/s eta 0:00:01\r\u001b[K |████████▊ | 583kB 17.6MB/s eta 0:00:01\r\u001b[K |████████▉ | 593kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████ | 604kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████▏ | 614kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████▎ | 624kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████▌ | 634kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████▋ | 645kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████▊ | 655kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████ | 665kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████ | 675kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████▎ | 686kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████▍ | 696kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████▌ | 706kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████▊ | 716kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████▉ | 727kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████ | 737kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████▏ | 747kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████▎ | 757kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████▌ | 768kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████▋ | 778kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████▊ | 788kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████ | 798kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████ | 808kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████▏ | 819kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████▍ | 829kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████▌ | 839kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████▊ | 849kB 17.6MB/s eta 0:00:01\r\u001b[K |████████████▉ | 860kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████ | 870kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████▏ | 880kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████▎ | 890kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████▍ | 901kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████▋ | 911kB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████▊ | 921kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████ | 931kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████ | 942kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████▏ | 952kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████▍ | 962kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████▌ | 972kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████▋ | 983kB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████▉ | 993kB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████ | 1.0MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 1.0MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████▎ | 1.0MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████▍ | 1.0MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████▋ | 1.0MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████▉ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████▏ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████▍ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████▌ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████▋ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████▉ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████ | 1.1MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████▎ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████▍ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████▌ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████▊ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████▉ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████▏ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████▌ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████▋ | 1.2MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████▊ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████▎ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████▍ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████▌ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████▊ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████▉ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████ | 1.3MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████▏ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████▎ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████▌ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████▊ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████▍ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████▌ | 1.4MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████▊ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████▉ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████▏ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████▎ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████▍ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████▋ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████▊ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 1.5MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████▍ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████▌ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████▋ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████▉ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████▏ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████▍ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████▋ | 1.6MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████▊ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████▉ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████▏ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████▍ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████▌ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████▋ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 1.7MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████▎ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████▍ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████▋ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████▊ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████▉ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████▏ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████▎ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████▌ | 1.8MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████▋ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████▉ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████▎ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████▍ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████▌ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████▊ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████▉ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████████ | 1.9MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▏ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▎ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▌ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▊ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▎ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▌ | 2.0MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▊ | 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▉ | 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▏| 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▎| 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▋| 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▊| 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 2.1MB 17.6MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 2.2MB 17.6MB/s \n", + "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (2019.12.20)\n", + "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (3.10.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (20.9)\n", + "Collecting tokenizers<0.11,>=0.10.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)\n", + "\u001b[K |████████████████████████████████| 3.3MB 54.4MB/s \n", + "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (4.41.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (3.0.12)\n", + "Collecting sacremoses\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)\n", + "\u001b[K |████████████████████████████████| 901kB 50.8MB/s \n", + "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (1.19.5)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.5.0) (2.23.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.5.0) (3.4.1)\n", + "Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from importlib-metadata; python_version < \"3.8\"->transformers==4.5.0) (3.7.4.3)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers==4.5.0) (2.4.7)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.5.0) (1.15.0)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.5.0) (7.1.2)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.5.0) (1.0.1)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.5.0) (3.0.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.5.0) (2020.12.5)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.5.0) (1.24.3)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.5.0) (2.10)\n", + "Installing collected packages: tokenizers, sacremoses, transformers\n", + "Successfully installed sacremoses-0.0.45 tokenizers-0.10.2 transformers-4.5.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8dKM4yCh4LI_" + }, + "source": [ + "## Import Packages" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WOTHHtWJoahe" + }, + "source": [ + "import json\n", + "import numpy as np\n", + "import random\n", + "import torch\n", + "from torch.utils.data import DataLoader, Dataset \n", + "from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# Fix random seed for reproducibility\n", + "def same_seeds(seed):\n", + "\t torch.manual_seed(seed)\n", + "\t if torch.cuda.is_available():\n", + "\t\t torch.cuda.manual_seed(seed)\n", + "\t\t torch.cuda.manual_seed_all(seed)\n", + "\t np.random.seed(seed)\n", + "\t random.seed(seed)\n", + "\t torch.backends.cudnn.benchmark = False\n", + "\t torch.backends.cudnn.deterministic = True\n", + "same_seeds(0)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "7pBtSZP1SKQO", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b55db708-681a-4a75-afdd-23f7e6b844d4" + }, + "source": [ + "# Change \"fp16_training\" to True to support automatic mixed precision training (fp16)\t\n", + "fp16_training = True\n", + "\n", + "if fp16_training:\n", + " !pip install accelerate==0.2.0\n", + " from accelerate import Accelerator\n", + " accelerator = Accelerator(fp16=True)\n", + " device = accelerator.device\n", + "\n", + "# Documentation for the toolkit: https://huggingface.co/docs/accelerate/" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting accelerate==0.2.0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/60/c6/6f08def78c19e328335236ec283a7c70e73913d1ed6f653ce2101bfad139/accelerate-0.2.0-py3-none-any.whl (47kB)\n", + "\r\u001b[K |███████ | 10kB 22.1MB/s eta 0:00:01\r\u001b[K |█████████████▉ | 20kB 20.1MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 30kB 20.7MB/s eta 0:00:01\r\u001b[K |███████████████████████████▊ | 40kB 21.7MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 51kB 7.4MB/s \n", + "\u001b[?25hRequirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from accelerate==0.2.0) (1.8.1+cu101)\n", + "Collecting pyaml>=20.4.0\n", + " Downloading https://files.pythonhosted.org/packages/15/c4/1310a054d33abc318426a956e7d6df0df76a6ddfa9c66f6310274fb75d42/pyaml-20.4.0-py2.py3-none-any.whl\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch>=1.4.0->accelerate==0.2.0) (1.19.5)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4.0->accelerate==0.2.0) (3.7.4.3)\n", + "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from pyaml>=20.4.0->accelerate==0.2.0) (3.13)\n", + "Installing collected packages: pyaml, accelerate\n", + "Successfully installed accelerate-0.2.0 pyaml-20.4.0\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2YgXHuVLp_6j" + }, + "source": [ + "## Load Model and Tokenizer\n", + "\n", + "\n", + "\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xyBCYGjAp3ym", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "09d5f0c7-4306-473c-dd65-f52ebfea8230" + }, + "source": [ + "# model = BertForQuestionAnswering.from_pretrained(\"bert-base-chinese\").to(device)\n", + "# tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-chinese\")\n", + "\n", + "from transformers import AutoTokenizer, AutoModelForQuestionAnswering\n", + " \n", + "tokenizer = AutoTokenizer.from_pretrained(\"wptoux/albert-chinese-large-qa\")\n", + "\n", + "model = AutoModelForQuestionAnswering.from_pretrained(\"wptoux/albert-chinese-large-qa\")\n", + "\n", + "# You can safely ignore the warning message (it pops up because new prediction heads for QA are initialized randomly)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", + "- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Td-GTmk5OW4" + }, + "source": [ + "## Read Data\n", + "\n", + "- Training set: 26935 QA pairs\n", + "- Dev set: 3523 QA pairs\n", + "- Test set: 3492 QA pairs\n", + "\n", + "- {train/dev/test}_questions:\t\n", + " - List of dicts with the following keys:\n", + " - id (int)\n", + " - paragraph_id (int)\n", + " - question_text (string)\n", + " - answer_text (string)\n", + " - answer_start (int)\n", + " - answer_end (int)\n", + "- {train/dev/test}_paragraphs: \n", + " - List of strings\n", + " - paragraph_ids in questions correspond to indexs in paragraphs\n", + " - A paragraph may be used by several questions " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "NvX7hlepogvu" + }, + "source": [ + "def read_data(file):\n", + " with open(file, 'r', encoding=\"utf-8\") as reader:\n", + " data = json.load(reader)\n", + " return data[\"questions\"], data[\"paragraphs\"]\n", + "\n", + "train_questions, train_paragraphs = read_data(\"hw7_train.json\")\n", + "dev_questions, dev_paragraphs = read_data(\"hw7_dev.json\")\n", + "test_questions, test_paragraphs = read_data(\"hw7_test.json\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Fm0rpTHq0e4N" + }, + "source": [ + "## Tokenize Data" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "rTZ6B70Hoxie", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "002a8d08-aa0d-4fcf-f738-b6bbd121da53" + }, + "source": [ + "# Tokenize questions and paragraphs separately\n", + "# 「add_special_tokens」 is set to False since special tokens will be added when tokenized questions and paragraphs are combined in datset __getitem__ \n", + "\n", + "train_questions_tokenized = tokenizer([train_question[\"question_text\"] for train_question in train_questions], add_special_tokens=False)\n", + "dev_questions_tokenized = tokenizer([dev_question[\"question_text\"] for dev_question in dev_questions], add_special_tokens=False)\n", + "test_questions_tokenized = tokenizer([test_question[\"question_text\"] for test_question in test_questions], add_special_tokens=False) \n", + "\n", + "train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)\n", + "dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)\n", + "test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)\n", + "\n", + "# You can safely ignore the warning message as tokenized sequences will be futher processed in datset __getitem__ before passing to model" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Token indices sequence length is longer than the specified maximum sequence length for this model (570 > 512). Running this sequence through the model will result in indexing errors\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ws8c8_4d5UCI" + }, + "source": [ + "## Dataset and Dataloader" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Xjooag-Swnuh" + }, + "source": [ + "class QA_Dataset(Dataset):\n", + " def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):\n", + " self.split = split\n", + " self.questions = questions\n", + " self.tokenized_questions = tokenized_questions\n", + " self.tokenized_paragraphs = tokenized_paragraphs\n", + " self.max_question_len = 40\n", + " self.max_paragraph_len = 384\n", + " \n", + " ##### TODO: Change value of doc_stride #####\n", + " self.doc_stride = 128\n", + "\n", + " # Input sequence length = [CLS] + question + [SEP] + paragraph + [SEP]\n", + " self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1\n", + "\n", + " def __len__(self):\n", + " return len(self.questions)\n", + "\n", + " def __getitem__(self, idx):\n", + " question = self.questions[idx]\n", + " tokenized_question = self.tokenized_questions[idx]\n", + " tokenized_paragraph = self.tokenized_paragraphs[question[\"paragraph_id\"]]\n", + "\n", + " ##### TODO: Preprocessing #####\n", + " # Hint: How to prevent model from learning something it should not learn\n", + "\n", + " if self.split == \"train\":\n", + " # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph \n", + " answer_start_token = tokenized_paragraph.char_to_token(question[\"answer_start\"])\n", + " answer_end_token = tokenized_paragraph.char_to_token(question[\"answer_end\"])\n", + "\n", + " # A single window is obtained by slicing the portion of paragraph containing the answer\n", + " # mid = (answer_start_token + answer_end_token) // 2\n", + " # mid = int(torch.randint(answer_start_token,answer_end_token+1,(1,)))\n", + " mid = ((answer_start_token + answer_end_token) // 2) + int(torch.randint(-50,50,(1,)))\n", + " paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))\n", + " paragraph_end = paragraph_start + self.max_paragraph_len\n", + " \n", + " # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n", + " input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102] \n", + " input_ids_paragraph = tokenized_paragraph.ids[paragraph_start : paragraph_end] + [102]\t\t\n", + " \n", + " # Convert answer's start/end positions in tokenized_paragraph to start/end positions in the window \n", + " answer_start_token += len(input_ids_question) - paragraph_start\n", + " answer_end_token += len(input_ids_question) - paragraph_start\n", + " \n", + " # Pad sequence and obtain inputs to model \n", + " input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n", + " return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token\n", + "\n", + " # Validation/Testing\n", + " else:\n", + " input_ids_list, token_type_ids_list, attention_mask_list = [], [], []\n", + " \n", + " # Paragraph is split into several windows, each with start positions separated by step \"doc_stride\"\n", + " for i in range(0, len(tokenized_paragraph), self.doc_stride):\n", + " \n", + " # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n", + " input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]\n", + " input_ids_paragraph = tokenized_paragraph.ids[i : i + self.max_paragraph_len] + [102]\n", + " \n", + " # Pad sequence and obtain inputs to model\n", + " input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n", + " \n", + " input_ids_list.append(input_ids)\n", + " token_type_ids_list.append(token_type_ids)\n", + " attention_mask_list.append(attention_mask)\n", + " \n", + " return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list)\n", + "\n", + " def padding(self, input_ids_question, input_ids_paragraph):\n", + " # Pad zeros if sequence length is shorter than max_seq_len\n", + " padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)\n", + " # Indices of input sequence tokens in the vocabulary\n", + " input_ids = input_ids_question + input_ids_paragraph + [0] * padding_len\n", + " # Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]\n", + " token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len\n", + " # Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]\n", + " attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len\n", + " \n", + " return input_ids, token_type_ids, attention_mask\n", + "\n", + "train_set = QA_Dataset(\"train\", train_questions, train_questions_tokenized, train_paragraphs_tokenized)\n", + "dev_set = QA_Dataset(\"dev\", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)\n", + "test_set = QA_Dataset(\"test\", test_questions, test_questions_tokenized, test_paragraphs_tokenized)\n", + "\n", + "train_batch_size = 8\n", + "\n", + "# Note: Do NOT change batch size of dev_loader / test_loader !\n", + "# Although batch size=1, it is actually a batch consisting of several windows from the same QA pair\n", + "train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=True)\n", + "dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=True)\n", + "test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=True)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5_H1kqhR8CdM" + }, + "source": [ + "## Function for Evaluation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SqeA3PLPxOHu" + }, + "source": [ + "def evaluate(data, output):\n", + " ##### TODO: Postprocessing #####\n", + " # There is a bug and room for improvement in postprocessing \n", + " # Hint: Open your prediction file to see what is wrong \n", + " \n", + " answer = ''\n", + " max_prob = float('-inf')\n", + " num_of_windows = data[0].shape[1]\n", + " \n", + " for k in range(num_of_windows):\n", + " # Obtain answer by choosing the most probable start position / end position\n", + " start_prob, start_index = torch.max(output.start_logits[k], dim=0)\n", + " end_prob, end_index = torch.max(output.end_logits[k], dim=0)\n", + " \n", + " # Probability of answer is calculated as sum of start_prob and end_prob\n", + " prob = start_prob + end_prob\n", + " \n", + " # Replace answer if calculated probability is larger than previous windows\n", + " if prob > max_prob:\n", + " max_prob = prob\n", + " # Convert tokens to chars (e.g. [1920, 7032] --> \"大 金\")\n", + " answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])\n", + " \n", + " # Remove spaces in answer (e.g. \"大 金\" --> \"大金\")\n", + " return answer.replace(' ','')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rzHQit6eMnKG" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3Q-B6ka7xoCM", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 475, + "referenced_widgets": [ + "2ed6ef4de01843598855909b4209480e", + "f2fd2cfef1714f4ca3c4a19d4f4c8029", + "1092a4b703aa429bb04b902424e88667", + "e31d3c7a5edc4ae49f3f2ee1aad98858", + "cabe2e3a668f49959005b8a8f1600e4c", + "847745386d86408a8ca243418a08c2a2", + "a590054b6ff542189409d24050743b51", + "6f376e003f274abface1e0f063b06c63" + ] + }, + "outputId": "ae5d08ab-c771-426e-a234-988b2521e0c5" + }, + "source": [ + "num_epoch = 2 \n", + "validation = True\n", + "logging_step = 100\n", + "learning_rate = 1e-5\n", + "optimizer = AdamW(model.parameters(), lr=learning_rate)\n", + "if fp16_training:\n", + " model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) \n", + "\n", + "\n", + "\n", + "model.train()\n", + "\n", + "print(\"Start Training ...\")\n", + "\n", + "best_acc = 0\n", + "\n", + "for epoch in range(num_epoch):\n", + " step = 1\n", + " train_loss = train_acc = 0\n", + " \n", + " for data in tqdm(train_loader):\t\n", + " # Load all data into GPU\n", + " data = [i.to(device) for i in data]\n", + " \n", + " # Model inputs: input_ids, token_type_ids, attention_mask, start_positions, end_positions (Note: only \"input_ids\" is mandatory)\n", + " # Model outputs: start_logits, end_logits, loss (return when start_positions/end_positions are provided) \n", + " output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])\n", + "\n", + " # Choose the most probable start position / end position\n", + " start_index = torch.argmax(output.start_logits, dim=1)\n", + " end_index = torch.argmax(output.end_logits, dim=1)\n", + " \n", + " # Prediction is correct only if both start_index and end_index are correct\n", + " train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()\n", + " train_loss += output.loss\n", + " \n", + " if fp16_training:\n", + " accelerator.backward(output.loss)\n", + " else:\n", + " output.loss.backward()\n", + " \n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + " step += 1\n", + "\n", + " ##### TODO: Apply linear learning rate decay #####\n", + " optimizer.param_groups[0][\"lr\"] -= learning_rate / (1685*num_epoch)\n", + "\n", + " # Print training loss and accuracy over past logging step\n", + " if step % logging_step == 0:\n", + " print(f\"Epoch {epoch + 1} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc = {train_acc / logging_step:.3f}\")\n", + " train_loss = train_acc = 0\n", + "\n", + " if validation:\n", + " print(\"Evaluating Dev Set ...\")\n", + " model.eval()\n", + " with torch.no_grad():\n", + " dev_acc = 0\n", + " for i, data in enumerate(tqdm(dev_loader)):\n", + " output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n", + " attention_mask=data[2].squeeze(dim=0).to(device))\n", + " # prediction is correct only if answer text exactly matches\n", + " dev_acc += evaluate(data, output) == dev_questions[i][\"answer_text\"]\n", + " print(f\"Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader):.3f}\")\n", + " model.train()\n", + "\n", + "# Save a model and its configuration file to the directory 「saved_model」 \n", + "# i.e. there are two files under the direcory 「saved_model」: 「pytorch_model.bin」 and 「config.json」\n", + "# Saved model can be re-loaded using 「model = BertForQuestionAnswering.from_pretrained(\"saved_model\")」\n", + "print(\"Saving Model ...\")\n", + "model_save_dir = \"saved_model\" \n", + "model.save_pretrained(model_save_dir)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Start Training ...\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2ed6ef4de01843598855909b4209480e", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=1684.0), HTML(value='')))" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Model inputs: input_ids, token_type_ids, attention_mask, start_positions, end_positions (Note: only \"input_ids\" is mandatory)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;31m# Model outputs: start_logits, end_logits, loss (return when start_positions/end_positions are provided)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoken_type_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_positions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_positions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;31m# Choose the most probable start position / end position\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/cuda/amp/autocast_mode.py\u001b[0m in \u001b[0;36mdecorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_autocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_autocast\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, start_positions, end_positions, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1782\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1783\u001b[0m \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1784\u001b[0;31m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1785\u001b[0m )\n\u001b[1;32m 1786\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 979\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 980\u001b[0m \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 981\u001b[0;31m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 982\u001b[0m )\n\u001b[1;32m 983\u001b[0m \u001b[0msequence_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 575\u001b[0;31m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 576\u001b[0m )\n\u001b[1;32m 577\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 461\u001b[0;31m \u001b[0mpast_key_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself_attn_past_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 462\u001b[0m )\n\u001b[1;32m 463\u001b[0m \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself_attention_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 395\u001b[0m )\n\u001b[1;32m 396\u001b[0m \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[0;31m# This is actually dropping out entire tokens to attend to, which might\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[0;31m# seem a bit unusual, but is taken from the original Transformer paper.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 319\u001b[0;31m \u001b[0mattention_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention_probs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 320\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 321\u001b[0m \u001b[0;31m# Mask heads if we want to\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/dropout.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mdropout\u001b[0;34m(input, p, training, inplace)\u001b[0m\n\u001b[1;32m 1074\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0.0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1075\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dropout probability has to be between 0 and 1, \"\u001b[0m \u001b[0;34m\"but got {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1076\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1077\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1078\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 28.00 MiB (GPU 0; 14.76 GiB total capacity; 13.23 GiB already allocated; 21.75 MiB free; 13.71 GiB reserved in total by PyTorch)" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kMmdLOKBMsdE" + }, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "U5scNKC9xz0C" + }, + "source": [ + "print(\"Evaluating Test Set ...\")\n", + "\n", + "result = []\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " for data in tqdm(test_loader):\n", + " output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),\n", + " attention_mask=data[2].squeeze(dim=0).to(device))\n", + " result.append(evaluate(data, output))\n", + "\n", + "result_file = \"result.csv\"\n", + "with open(result_file, 'w') as f:\t\n", + "\t f.write(\"ID,Answer\\n\")\n", + "\t for i, test_question in enumerate(test_questions):\n", + " # Replace commas in answers with empty strings (since csv is separated by comma)\n", + " # Answers in kaggle are processed in the same way\n", + "\t\t f.write(f\"{test_question['id']},{result[i].replace(',','')}\\n\")\n", + "\n", + "print(f\"Completed! Result is in {result_file}\")" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file