From 262fb4bfec6ed6bcc14f50ac631f4806a3ed35d8 Mon Sep 17 00:00:00 2001 From: Sherry-XLL Date: Mon, 4 Apr 2022 04:41:45 +0000 Subject: [PATCH 1/5] FEA: add tutorials of prediction --- ...ith-item-infor-fix-missing-last-item.ipynb | 3449 +++++++++++++++++ ...cbole-using-all-items-for-prediction.ipynb | 1969 ++++++++++ ...ential-model-fixed-missing-last-item.ipynb | 2865 ++++++++++++++ 3 files changed, 8283 insertions(+) create mode 100644 run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb create mode 100644 run_example/recbole-using-all-items-for-prediction.ipynb create mode 100644 run_example/sequential-model-fixed-missing-last-item.ipynb diff --git a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb new file mode 100644 index 000000000..84cfd2125 --- /dev/null +++ b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb @@ -0,0 +1,3449 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4a1a5970", + "metadata": { + "papermill": { + "duration": 0.067527, + "end_time": "2022-03-30T00:41:38.166135", + "exception": false, + "start_time": "2022-03-30T00:41:38.098608", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 0.Overview\n", + "\n", + "**Edit:**\n", + "\n", + "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", + "* So i created a notebook [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction) for address there questions in detail, and this notebook is an improved of my [previous notebook](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial), applying our new function (using all item as input features without `full_sort_topk`) for this competition. In this notebook, we also use item features as input.\n", + "* If you only want to use interaction as input feature, please check this [notebook](https://www.kaggle.com/astrung/lstm-model-with-item-infor-fix-missing-last-item).\n", + "\n", + "- - -\n", + "\n", + "In previous [my notebook](https://www.kaggle.com/code/astrung/sequential-model-fixed-missing-last-item/), we tried to use GRU/LSTM model for testing effect of sequential model for recommendation with only iteration.\n", + "In this notebook, i showed how we can enhance sequential model with item features \n", + "\n", + "Due to memory limit and faster testing purpose, we will just use data in 2020.\n", + "\n", + "If you want to use with all of interactions in all time, i have created a new atomic dataset here for you: \n", + "\n", + "* only interations data: https://www.kaggle.com/astrung/hm-atomic-interation\n", + "* iterations + item features data: https://www.kaggle.com/astrung/hm-atomic-interation-with-item-feature \n", + "\n", + "We also have other limit: we only train model and predict with users who buy more than 40 items and items which is bought by more than 40 people.\n", + "\n", + "We will follow below steps for creating model:\n", + "\n", + "1. In order to use Recbole, we create atomic file from interaction data and item data\n", + "2. Because we only use Recbole model for predicting with users who buy more than 40 items, other users will need to fill by default recomendation items. We create most viewed items in last month as defautl recomendation\n", + "3. We create dataset and train model in recbole.\n", + "4. We create prediction result by trained model\n", + "5. We combine recomendation result from most viewed items in last month and Recbole predicted model.\n", + "\n", + "I will explain more detail in following cells.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f2a64764", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:41:38.271362Z", + "iopub.status.busy": "2022-03-30T00:41:38.270252Z", + "iopub.status.idle": "2022-03-30T00:41:59.220822Z", + "shell.execute_reply": "2022-03-30T00:41:59.220114Z", + "shell.execute_reply.started": "2022-03-20T03:28:15.953594Z" + }, + "papermill": { + "duration": 21.009084, + "end_time": "2022-03-30T00:41:59.221012", + "exception": false, + "start_time": "2022-03-30T00:41:38.211928", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting recbole\r\n", + " Downloading recbole-1.0.1-py3-none-any.whl (2.0 MB)\r\n", + " |████████████████████████████████| 2.0 MB 533 kB/s \r\n", + "\u001b[?25hCollecting scipy==1.6.0\r\n", + " Downloading scipy-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (27.4 MB)\r\n", + " |████████████████████████████████| 27.4 MB 100 kB/s \r\n", + "\u001b[?25hRequirement already satisfied: pandas>=1.0.5 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.3.5)\r\n", + "Collecting colorlog==4.7.2\r\n", + " Downloading colorlog-4.7.2-py2.py3-none-any.whl (10 kB)\r\n", + "Requirement already satisfied: colorama==0.4.4 in /opt/conda/lib/python3.7/site-packages (from recbole) (0.4.4)\r\n", + "Requirement already satisfied: tqdm>=4.48.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (4.62.3)\r\n", + "Requirement already satisfied: pyyaml>=5.1.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (6.0)\r\n", + "Requirement already satisfied: scikit-learn>=0.23.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (0.23.2)\r\n", + "Requirement already satisfied: torch>=1.7.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.9.1)\r\n", + "Requirement already satisfied: numpy>=1.17.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.20.3)\r\n", + "Requirement already satisfied: tensorboard>=2.5.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (2.6.0)\r\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas>=1.0.5->recbole) (2.8.2)\r\n", + "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas>=1.0.5->recbole) (2021.3)\r\n", + "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn>=0.23.2->recbole) (1.1.0)\r\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn>=0.23.2->recbole) (3.0.0)\r\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.8.0)\r\n", + "Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.43.0)\r\n", + "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.15.0)\r\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.6.1)\r\n", + "Requirement already satisfied: google-auth<2,>=1.6.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.35.0)\r\n", + "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (3.3.6)\r\n", + "Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.37.0)\r\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.4.6)\r\n", + "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (59.5.0)\r\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (2.0.2)\r\n", + "Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (3.19.1)\r\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (2.26.0)\r\n", + "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.7.0->recbole) (4.0.1)\r\n", + "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from absl-py>=0.4->tensorboard>=2.5.0->recbole) (1.16.0)\r\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (0.2.7)\r\n", + "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (4.2.4)\r\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (4.8)\r\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.5.0->recbole) (1.3.0)\r\n", + "Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.7/site-packages (from markdown>=2.6.8->tensorboard>=2.5.0->recbole) (4.10.1)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (3.1)\r\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (2.0.9)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (2021.10.8)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (1.26.7)\r\n", + "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.5.0->recbole) (3.6.0)\r\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (0.4.8)\r\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.5.0->recbole) (3.1.1)\r\n", + "Installing collected packages: scipy, colorlog, recbole\r\n", + " Attempting uninstall: scipy\r\n", + " Found existing installation: scipy 1.7.3\r\n", + " Uninstalling scipy-1.7.3:\r\n", + " Successfully uninstalled scipy-1.7.3\r\n", + " Attempting uninstall: colorlog\r\n", + " Found existing installation: colorlog 6.6.0\r\n", + " Uninstalling colorlog-6.6.0:\r\n", + " Successfully uninstalled colorlog-6.6.0\r\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n", + "yellowbrick 1.3.post1 requires numpy<1.20,>=1.16.0, but you have numpy 1.20.3 which is incompatible.\r\n", + "pdpbox 0.2.1 requires matplotlib==3.1.1, but you have matplotlib 3.5.1 which is incompatible.\r\n", + "imbalanced-learn 0.9.0 requires scikit-learn>=1.0.1, but you have scikit-learn 0.23.2 which is incompatible.\r\n", + "featuretools 1.4.1 requires numpy>=1.21.0, but you have numpy 1.20.3 which is incompatible.\r\n", + "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.0.1 which is incompatible.\u001b[0m\r\n", + "Successfully installed colorlog-4.7.2 recbole-1.0.1 scipy-1.6.0\r\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\r\n" + ] + } + ], + "source": [ + "!pip install recbole" + ] + }, + { + "cell_type": "markdown", + "id": "d55f3b0c", + "metadata": { + "papermill": { + "duration": 0.069231, + "end_time": "2022-03-30T00:41:59.365650", + "exception": false, + "start_time": "2022-03-30T00:41:59.296419", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 1. Create atomic file" + ] + }, + { + "cell_type": "markdown", + "id": "d4ed7947", + "metadata": { + "papermill": { + "duration": 0.070207, + "end_time": "2022-03-30T00:41:59.507241", + "exception": false, + "start_time": "2022-03-30T00:41:59.437034", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### 1.A create atomic of item features\n", + "we will create item features for feeding with iteration features into GRU4REC model " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a2d005de", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:41:59.698974Z", + "iopub.status.busy": "2022-03-30T00:41:59.695833Z", + "iopub.status.idle": "2022-03-30T00:42:01.232285Z", + "shell.execute_reply": "2022-03-30T00:42:01.232821Z", + "shell.execute_reply.started": "2022-03-20T03:28:32.486357Z" + }, + "papermill": { + "duration": 1.655417, + "end_time": "2022-03-30T00:42:01.232989", + "exception": false, + "start_time": "2022-03-30T00:41:59.577572", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
article_idproduct_codeprod_nameproduct_type_noproduct_type_nameproduct_group_namegraphical_appearance_nographical_appearance_namecolour_group_codecolour_group_name...department_nameindex_codeindex_nameindex_group_noindex_group_namesection_nosection_namegarment_group_nogarment_group_namedetail_desc
00108775015108775Strap top253Vest topGarment Upper body1010016Solid9Black...Jersey BasicALadieswear1Ladieswear16Womens Everyday Basics1002Jersey BasicJersey top with narrow shoulder straps.
10108775044108775Strap top253Vest topGarment Upper body1010016Solid10White...Jersey BasicALadieswear1Ladieswear16Womens Everyday Basics1002Jersey BasicJersey top with narrow shoulder straps.
20108775051108775Strap top (1)253Vest topGarment Upper body1010017Stripe11Off White...Jersey BasicALadieswear1Ladieswear16Womens Everyday Basics1002Jersey BasicJersey top with narrow shoulder straps.
30110065001110065OP T-shirt (Idro)306BraUnderwear1010016Solid9Black...Clean LingerieBLingeries/Tights1Ladieswear61Womens Lingerie1017Under-, NightwearMicrofibre T-shirt bra with underwired, moulde...
40110065002110065OP T-shirt (Idro)306BraUnderwear1010016Solid10White...Clean LingerieBLingeries/Tights1Ladieswear61Womens Lingerie1017Under-, NightwearMicrofibre T-shirt bra with underwired, moulde...
\n", + "

5 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " article_id product_code prod_name product_type_no \\\n", + "0 0108775015 108775 Strap top 253 \n", + "1 0108775044 108775 Strap top 253 \n", + "2 0108775051 108775 Strap top (1) 253 \n", + "3 0110065001 110065 OP T-shirt (Idro) 306 \n", + "4 0110065002 110065 OP T-shirt (Idro) 306 \n", + "\n", + " product_type_name product_group_name graphical_appearance_no \\\n", + "0 Vest top Garment Upper body 1010016 \n", + "1 Vest top Garment Upper body 1010016 \n", + "2 Vest top Garment Upper body 1010017 \n", + "3 Bra Underwear 1010016 \n", + "4 Bra Underwear 1010016 \n", + "\n", + " graphical_appearance_name colour_group_code colour_group_name ... \\\n", + "0 Solid 9 Black ... \n", + "1 Solid 10 White ... \n", + "2 Stripe 11 Off White ... \n", + "3 Solid 9 Black ... \n", + "4 Solid 10 White ... \n", + "\n", + " department_name index_code index_name index_group_no \\\n", + "0 Jersey Basic A Ladieswear 1 \n", + "1 Jersey Basic A Ladieswear 1 \n", + "2 Jersey Basic A Ladieswear 1 \n", + "3 Clean Lingerie B Lingeries/Tights 1 \n", + "4 Clean Lingerie B Lingeries/Tights 1 \n", + "\n", + " index_group_name section_no section_name garment_group_no \\\n", + "0 Ladieswear 16 Womens Everyday Basics 1002 \n", + "1 Ladieswear 16 Womens Everyday Basics 1002 \n", + "2 Ladieswear 16 Womens Everyday Basics 1002 \n", + "3 Ladieswear 61 Womens Lingerie 1017 \n", + "4 Ladieswear 61 Womens Lingerie 1017 \n", + "\n", + " garment_group_name detail_desc \n", + "0 Jersey Basic Jersey top with narrow shoulder straps. \n", + "1 Jersey Basic Jersey top with narrow shoulder straps. \n", + "2 Jersey Basic Jersey top with narrow shoulder straps. \n", + "3 Under-, Nightwear Microfibre T-shirt bra with underwired, moulde... \n", + "4 Under-, Nightwear Microfibre T-shirt bra with underwired, moulde... \n", + "\n", + "[5 rows x 25 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "import gc\n", + "df = pd.read_csv(r\"/kaggle/input/h-and-m-personalized-fashion-recommendations/articles.csv\", dtype={'article_id': 'str'})\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2633bb84", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:42:01.392374Z", + "iopub.status.busy": "2022-03-30T00:42:01.391390Z", + "iopub.status.idle": "2022-03-30T00:42:01.553293Z", + "shell.execute_reply": "2022-03-30T00:42:01.552314Z", + "shell.execute_reply.started": "2022-03-20T03:28:33.491829Z" + }, + "papermill": { + "duration": 0.248859, + "end_time": "2022-03-30T00:42:01.553446", + "exception": false, + "start_time": "2022-03-30T00:42:01.304587", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "article_id\n", + "105542\n", + "product_code\n", + "47224\n", + "prod_name\n", + "45875\n", + "product_type_no\n", + "132\n", + "product_type_name\n", + "131\n", + "product_group_name\n", + "19\n", + "graphical_appearance_no\n", + "30\n", + "graphical_appearance_name\n", + "30\n", + "colour_group_code\n", + "50\n", + "colour_group_name\n", + "50\n", + "perceived_colour_value_id\n", + "8\n", + "perceived_colour_value_name\n", + "8\n", + "perceived_colour_master_id\n", + "20\n", + "perceived_colour_master_name\n", + "20\n", + "department_no\n", + "299\n", + "department_name\n", + "250\n", + "index_code\n", + "10\n", + "index_name\n", + "10\n", + "index_group_no\n", + "5\n", + "index_group_name\n", + "5\n", + "section_no\n", + "57\n", + "section_name\n", + "56\n", + "garment_group_no\n", + "21\n", + "garment_group_name\n", + "21\n", + "detail_desc\n", + "43405\n" + ] + } + ], + "source": [ + "for col in df.columns:\n", + " print(col)\n", + " print(len(pd.unique(df[col])))" + ] + }, + { + "cell_type": "markdown", + "id": "df90ab4d", + "metadata": { + "papermill": { + "duration": 0.074378, + "end_time": "2022-03-30T00:42:01.703487", + "exception": false, + "start_time": "2022-03-30T00:42:01.629109", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "#### we see so many couple of columns are [category_text, encoded_value]. So in order to avoid [Multicollinearity](https://link.springer.com/chapter/10.1007/978-0-585-25657-3_37), we will keep only one columns in each couple\n", + "We can see below couple of columns in item features, and we will keep one of them:\n", + "\n", + "* use product_type_no - skip product_type_name\n", + "* use graphical_appearance_no - skip graphical_appearance_name\n", + "* use colour_group_code - skip colour_group_name\n", + "* use perceived_colour_value_id - skip perceived_colour_value_name\n", + "* use perceived_colour_master_id - skip perceived_colour_master_name\n", + "* use index_code - skip index_name\n", + "* use index_group_no - skip index_group_name\n", + "* use section_no - skip section_name\n", + "* use garment_group_no - skip garment_group_name\n", + "* use product_code, skip product_name\n", + "* use department_no, skip department_name" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d3defd2d", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:42:01.856833Z", + "iopub.status.busy": "2022-03-30T00:42:01.855783Z", + "iopub.status.idle": "2022-03-30T00:42:01.877762Z", + "shell.execute_reply": "2022-03-30T00:42:01.878722Z", + "shell.execute_reply.started": "2022-03-20T03:28:33.646379Z" + }, + "papermill": { + "duration": 0.10181, + "end_time": "2022-03-30T00:42:01.878893", + "exception": false, + "start_time": "2022-03-30T00:42:01.777083", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
article_idproduct_codeproduct_type_noproduct_group_namegraphical_appearance_nocolour_group_codeperceived_colour_value_idperceived_colour_master_iddepartment_noindex_codeindex_group_nosection_nogarment_group_no
00108775015108775253Garment Upper body10100169451676A1161002
10108775044108775253Garment Upper body101001610391676A1161002
20108775051108775253Garment Upper body101001711191676A1161002
30110065001110065306Underwear10100169451339B1611017
40110065002110065306Underwear101001610391339B1611017
\n", + "
" + ], + "text/plain": [ + " article_id product_code product_type_no product_group_name \\\n", + "0 0108775015 108775 253 Garment Upper body \n", + "1 0108775044 108775 253 Garment Upper body \n", + "2 0108775051 108775 253 Garment Upper body \n", + "3 0110065001 110065 306 Underwear \n", + "4 0110065002 110065 306 Underwear \n", + "\n", + " graphical_appearance_no colour_group_code perceived_colour_value_id \\\n", + "0 1010016 9 4 \n", + "1 1010016 10 3 \n", + "2 1010017 11 1 \n", + "3 1010016 9 4 \n", + "4 1010016 10 3 \n", + "\n", + " perceived_colour_master_id department_no index_code index_group_no \\\n", + "0 5 1676 A 1 \n", + "1 9 1676 A 1 \n", + "2 9 1676 A 1 \n", + "3 5 1339 B 1 \n", + "4 9 1339 B 1 \n", + "\n", + " section_no garment_group_no \n", + "0 16 1002 \n", + "1 16 1002 \n", + "2 16 1002 \n", + "3 61 1017 \n", + "4 61 1017 " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = df.drop(columns = ['product_type_name', 'graphical_appearance_name', 'colour_group_name', 'perceived_colour_value_name',\n", + " 'perceived_colour_master_name', 'index_name', 'index_group_name', 'section_name', \n", + " 'garment_group_name', 'prod_name', 'department_name', 'detail_desc'])\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7f4269f2", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:42:02.035677Z", + "iopub.status.busy": "2022-03-30T00:42:02.034319Z", + "iopub.status.idle": "2022-03-30T00:42:02.053128Z", + "shell.execute_reply": "2022-03-30T00:42:02.053834Z", + "shell.execute_reply.started": "2022-03-20T03:28:33.667553Z" + }, + "papermill": { + "duration": 0.101736, + "end_time": "2022-03-30T00:42:02.054009", + "exception": false, + "start_time": "2022-03-30T00:42:01.952273", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_id:tokenproduct_code:tokenproduct_type_no:floatproduct_group_name:token_seqgraphical_appearance_no:tokencolour_group_code:tokenperceived_colour_value_id:tokenperceived_colour_master_id:tokendepartment_no:tokenindex_code:tokenindex_group_no:tokensection_no:tokengarment_group_no:token
00108775015108775253Garment Upper body10100169451676A1161002
10108775044108775253Garment Upper body101001610391676A1161002
20108775051108775253Garment Upper body101001711191676A1161002
30110065001110065306Underwear10100169451339B1611017
40110065002110065306Underwear101001610391339B1611017
\n", + "
" + ], + "text/plain": [ + " item_id:token product_code:token product_type_no:float \\\n", + "0 0108775015 108775 253 \n", + "1 0108775044 108775 253 \n", + "2 0108775051 108775 253 \n", + "3 0110065001 110065 306 \n", + "4 0110065002 110065 306 \n", + "\n", + " product_group_name:token_seq graphical_appearance_no:token \\\n", + "0 Garment Upper body 1010016 \n", + "1 Garment Upper body 1010016 \n", + "2 Garment Upper body 1010017 \n", + "3 Underwear 1010016 \n", + "4 Underwear 1010016 \n", + "\n", + " colour_group_code:token perceived_colour_value_id:token \\\n", + "0 9 4 \n", + "1 10 3 \n", + "2 11 1 \n", + "3 9 4 \n", + "4 10 3 \n", + "\n", + " perceived_colour_master_id:token department_no:token index_code:token \\\n", + "0 5 1676 A \n", + "1 9 1676 A \n", + "2 9 1676 A \n", + "3 5 1339 B \n", + "4 9 1339 B \n", + "\n", + " index_group_no:token section_no:token garment_group_no:token \n", + "0 1 16 1002 \n", + "1 1 16 1002 \n", + "2 1 16 1002 \n", + "3 1 61 1017 \n", + "4 1 61 1017 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp = df.rename(\n", + " columns={'article_id': 'item_id:token', 'product_code': 'product_code:token', 'product_type_no': 'product_type_no:float',\n", + " 'product_group_name': 'product_group_name:token_seq', 'graphical_appearance_no': 'graphical_appearance_no:token', \n", + " 'colour_group_code': 'colour_group_code:token', 'perceived_colour_value_id': 'perceived_colour_value_id:token', \n", + " 'perceived_colour_master_id': 'perceived_colour_master_id:token', 'department_no': 'department_no:token', \n", + " 'index_code': 'index_code:token', 'index_group_no': 'index_group_no:token', 'section_no': 'section_no:token', \n", + " 'garment_group_no': 'garment_group_no:token'})\n", + "temp.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3caa3a90", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:42:02.208337Z", + "iopub.status.busy": "2022-03-30T00:42:02.207395Z", + "iopub.status.idle": "2022-03-30T00:42:03.458221Z", + "shell.execute_reply": "2022-03-30T00:42:03.457583Z", + "shell.execute_reply.started": "2022-03-20T03:28:33.691403Z" + }, + "papermill": { + "duration": 1.330234, + "end_time": "2022-03-30T00:42:03.458383", + "exception": false, + "start_time": "2022-03-30T00:42:02.128149", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!mkdir /kaggle/working/recbox_data\n", + "temp.to_csv(r'/kaggle/working/recbox_data/recbox_data.item', index=False, sep='\\t')" + ] + }, + { + "cell_type": "markdown", + "id": "3466077e", + "metadata": { + "papermill": { + "duration": 0.071684, + "end_time": "2022-03-30T00:42:03.602471", + "exception": false, + "start_time": "2022-03-30T00:42:03.530787", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### 1.B create atomic of iteration features\n", + "we will create iteration features for GRU4REC model " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1d0b41f8", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:42:03.755739Z", + "iopub.status.busy": "2022-03-30T00:42:03.754738Z", + "iopub.status.idle": "2022-03-30T00:43:11.067732Z", + "shell.execute_reply": "2022-03-30T00:43:11.068281Z", + "shell.execute_reply.started": "2022-03-20T03:28:34.839528Z" + }, + "papermill": { + "duration": 67.393475, + "end_time": "2022-03-30T00:43:11.068452", + "exception": false, + "start_time": "2022-03-30T00:42:03.674977", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_id
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.0508312
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.0304922
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.0152372
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.0169322
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.0169322
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id article_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0663713001 \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0541518023 \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0505221004 \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687003 \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687004 \n", + "\n", + " price sales_channel_id \n", + "0 0.050831 2 \n", + "1 0.030492 2 \n", + "2 0.015237 2 \n", + "3 0.016932 2 \n", + "4 0.016932 2 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv(r\"/kaggle/input/h-and-m-personalized-fashion-recommendations/transactions_train.csv\", \n", + " dtype={'article_id': 'str'})\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5530304c", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:43:11.675651Z", + "iopub.status.busy": "2022-03-30T00:43:11.674406Z", + "iopub.status.idle": "2022-03-30T00:43:17.613262Z", + "shell.execute_reply": "2022-03-30T00:43:17.612408Z", + "shell.execute_reply.started": "2022-03-20T03:30:08.593969Z" + }, + "papermill": { + "duration": 6.473316, + "end_time": "2022-03-30T00:43:17.613404", + "exception": false, + "start_time": "2022-03-30T00:43:11.140088", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_id
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.0508312
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.0304922
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.0152372
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.0169322
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.0169322
..................
317883192020-09-22fff2282977442e327b45d8c89afde25617d00124d0f999...09295110010.0593052
317883202020-09-22fff2282977442e327b45d8c89afde25617d00124d0f999...08913220040.0423562
317883212020-09-22fff380805474b287b05cb2a7507b9a013482f7dd0bce0e...09183250010.0432031
317883222020-09-22fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5...08334590020.0067631
317883232020-09-22fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20...08985730030.0338812
\n", + "

31788324 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " t_dat customer_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... \n", + "... ... ... \n", + "31788319 2020-09-22 fff2282977442e327b45d8c89afde25617d00124d0f999... \n", + "31788320 2020-09-22 fff2282977442e327b45d8c89afde25617d00124d0f999... \n", + "31788321 2020-09-22 fff380805474b287b05cb2a7507b9a013482f7dd0bce0e... \n", + "31788322 2020-09-22 fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5... \n", + "31788323 2020-09-22 fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20... \n", + "\n", + " article_id price sales_channel_id \n", + "0 0663713001 0.050831 2 \n", + "1 0541518023 0.030492 2 \n", + "2 0505221004 0.015237 2 \n", + "3 0685687003 0.016932 2 \n", + "4 0685687004 0.016932 2 \n", + "... ... ... ... \n", + "31788319 0929511001 0.059305 2 \n", + "31788320 0891322004 0.042356 2 \n", + "31788321 0918325001 0.043203 1 \n", + "31788322 0833459002 0.006763 1 \n", + "31788323 0898573003 0.033881 2 \n", + "\n", + "[31788324 rows x 5 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df['t_dat'] = pd.to_datetime(df['t_dat'], format=\"%Y-%m-%d\")\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d0d23d57", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:43:17.875610Z", + "iopub.status.busy": "2022-03-30T00:43:17.874533Z", + "iopub.status.idle": "2022-03-30T00:43:18.680561Z", + "shell.execute_reply": "2022-03-30T00:43:18.681754Z", + "shell.execute_reply.started": "2022-03-20T03:30:14.348681Z" + }, + "papermill": { + "duration": 0.973379, + "end_time": "2022-03-30T00:43:18.681994", + "exception": false, + "start_time": "2022-03-30T00:43:17.708615", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_idtimestamp
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.05083121537401600
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.03049221537401600
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.01523721537401600
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.01693221537401600
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.01693221537401600
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id article_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0663713001 \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0541518023 \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0505221004 \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687003 \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687004 \n", + "\n", + " price sales_channel_id timestamp \n", + "0 0.050831 2 1537401600 \n", + "1 0.030492 2 1537401600 \n", + "2 0.015237 2 1537401600 \n", + "3 0.016932 2 1537401600 \n", + "4 0.016932 2 1537401600 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "df['timestamp'] = df.t_dat.values.astype(np.int64) // 10 ** 9\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "04aad5a7", + "metadata": { + "papermill": { + "duration": 0.121777, + "end_time": "2022-03-30T00:43:18.925901", + "exception": false, + "start_time": "2022-03-30T00:43:18.804124", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "**We fill with data in only 2020(timestamp > > 1585620000) and create inter file**\n", + "For anyone need instruction about inter file, please check below links:\n", + "* https://recbole.io/docs/user_guide/data_intro.html\n", + "* https://recbole.io/docs/user_guide/data/atomic_files.html\n", + "\n", + "if you want a full of iterations without limiting timestamp, please check here:\n", + "* " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "451144d2", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:43:19.168703Z", + "iopub.status.busy": "2022-03-30T00:43:19.167608Z", + "iopub.status.idle": "2022-03-30T00:43:20.860491Z", + "shell.execute_reply": "2022-03-30T00:43:20.861014Z", + "shell.execute_reply.started": "2022-03-20T03:30:14.866117Z" + }, + "papermill": { + "duration": 1.812346, + "end_time": "2022-03-30T00:43:20.861202", + "exception": false, + "start_time": "2022-03-30T00:43:19.048856", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_id:tokenitem_id:tokentimestamp:float
23934157000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...07278080011585699200
23934158000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...07278080071585699200
23934159000563485cbb7850b0a93c6606f89c5b961c6647d1bd48...05675320151585699200
23934160000563485cbb7850b0a93c6606f89c5b961c6647d1bd48...07061040091585699200
2393416100083cda041544b2fbb0e0d2905ad17da7cf1007526fb4...07835040041585699200
............
31788319fff2282977442e327b45d8c89afde25617d00124d0f999...09295110011600732800
31788320fff2282977442e327b45d8c89afde25617d00124d0f999...08913220041600732800
31788321fff380805474b287b05cb2a7507b9a013482f7dd0bce0e...09183250011600732800
31788322fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5...08334590021600732800
31788323fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20...08985730031600732800
\n", + "

7854167 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id:token item_id:token \\\n", + "23934157 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0727808001 \n", + "23934158 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0727808007 \n", + "23934159 000563485cbb7850b0a93c6606f89c5b961c6647d1bd48... 0567532015 \n", + "23934160 000563485cbb7850b0a93c6606f89c5b961c6647d1bd48... 0706104009 \n", + "23934161 00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4... 0783504004 \n", + "... ... ... \n", + "31788319 fff2282977442e327b45d8c89afde25617d00124d0f999... 0929511001 \n", + "31788320 fff2282977442e327b45d8c89afde25617d00124d0f999... 0891322004 \n", + "31788321 fff380805474b287b05cb2a7507b9a013482f7dd0bce0e... 0918325001 \n", + "31788322 fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5... 0833459002 \n", + "31788323 fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20... 0898573003 \n", + "\n", + " timestamp:float \n", + "23934157 1585699200 \n", + "23934158 1585699200 \n", + "23934159 1585699200 \n", + "23934160 1585699200 \n", + "23934161 1585699200 \n", + "... ... \n", + "31788319 1600732800 \n", + "31788320 1600732800 \n", + "31788321 1600732800 \n", + "31788322 1600732800 \n", + "31788323 1600732800 \n", + "\n", + "[7854167 rows x 3 columns]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp = df[df['timestamp'] > 1585620000][['customer_id', 'article_id', 'timestamp']].rename(\n", + " columns={'customer_id': 'user_id:token', 'article_id': 'item_id:token', 'timestamp': 'timestamp:float'})\n", + "temp" + ] + }, + { + "cell_type": "markdown", + "id": "6c8a6b12", + "metadata": { + "papermill": { + "duration": 0.075369, + "end_time": "2022-03-30T00:43:21.011379", + "exception": false, + "start_time": "2022-03-30T00:43:20.936010", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We save atomic file in dataset format for using with recbole" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "08430886", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:43:21.175635Z", + "iopub.status.busy": "2022-03-30T00:43:21.174627Z", + "iopub.status.idle": "2022-03-30T00:43:57.249592Z", + "shell.execute_reply": "2022-03-30T00:43:57.250218Z", + "shell.execute_reply.started": "2022-03-20T03:30:16.32046Z" + }, + "papermill": { + "duration": 36.163301, + "end_time": "2022-03-30T00:43:57.250409", + "exception": false, + "start_time": "2022-03-30T00:43:21.087108", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "160" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp.to_csv('/kaggle/working/recbox_data/recbox_data.inter', index=False, sep='\\t')\n", + "del temp\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "d3fd7cf5", + "metadata": { + "papermill": { + "duration": 0.076297, + "end_time": "2022-03-30T00:43:57.402842", + "exception": false, + "start_time": "2022-03-30T00:43:57.326545", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 2. We create defautl recomendation for user who can not be predicted by sequential model.\n", + "I use this approach in notebook: https://www.kaggle.com/hervind/h-m-faster-trending-products-weekly You can check it for more detail information. I will juse copy only code here" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "4950cf8b", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:43:57.559376Z", + "iopub.status.busy": "2022-03-30T00:43:57.558307Z", + "iopub.status.idle": "2022-03-30T00:43:57.561403Z", + "shell.execute_reply": "2022-03-30T00:43:57.560913Z", + "shell.execute_reply.started": "2022-03-20T03:30:50.636671Z" + }, + "papermill": { + "duration": 0.083973, + "end_time": "2022-03-30T00:43:57.561578", + "exception": false, + "start_time": "2022-03-30T00:43:57.477605", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "04e3fb11", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:43:57.720033Z", + "iopub.status.busy": "2022-03-30T00:43:57.719431Z", + "iopub.status.idle": "2022-03-30T00:44:50.466451Z", + "shell.execute_reply": "2022-03-30T00:44:50.467107Z", + "shell.execute_reply.started": "2022-03-20T03:30:50.642822Z" + }, + "papermill": { + "duration": 52.830362, + "end_time": "2022-03-30T00:44:50.467278", + "exception": false, + "start_time": "2022-03-30T00:43:57.636916", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((1371980, 2), (1371980, 2), (1371980, 2))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sub0 = pd.read_csv('../input/hm-pre-recommendation/submissio_byfone_chris.csv').sort_values('customer_id').reset_index(drop=True)\n", + "sub1 = pd.read_csv('../input/hm-pre-recommendation/submission_trending.csv').sort_values('customer_id').reset_index(drop=True)\n", + "sub2 = pd.read_csv('../input/hm-pre-recommendation/submission_exponential_decay.csv').sort_values('customer_id').reset_index(drop=True)\n", + "\n", + "sub0.shape, sub1.shape, sub2.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0c14ce6e", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:44:50.626546Z", + "iopub.status.busy": "2022-03-30T00:44:50.625317Z", + "iopub.status.idle": "2022-03-30T00:44:50.887372Z", + "shell.execute_reply": "2022-03-30T00:44:50.886795Z", + "shell.execute_reply.started": "2022-03-20T03:31:44.176782Z" + }, + "papermill": { + "duration": 0.344738, + "end_time": "2022-03-30T00:44:50.887513", + "exception": false, + "start_time": "2022-03-30T00:44:50.542775", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction0prediction1prediction2
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 07...0568601043 0568601006 0656719005 0745232001 07...0568601043 0924243001 0924243002 0918522001 07...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0739590027 0723529001 08...0826211002 0800436010 0739590027 0723529001 08...0924243001 0924243002 0918522001 0751471001 04...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 07...0794321007 0852643001 0852643003 0858883002 07...0794321007 0924243001 0924243002 0918522001 07...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0751471001 0706016001 06...0448509014 0573085028 0751471001 0706016001 06...0924243001 0924243002 0918522001 0751471001 04...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0896152002 0818320001 09...0730683050 0791587015 0896152002 0818320001 09...0924243001 0924243002 0918522001 0751471001 04...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction0 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction1 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction2 \n", + "0 0568601043 0924243001 0924243002 0918522001 07... \n", + "1 0924243001 0924243002 0918522001 0751471001 04... \n", + "2 0794321007 0924243001 0924243002 0918522001 07... \n", + "3 0924243001 0924243002 0918522001 0751471001 04... \n", + "4 0924243001 0924243002 0918522001 0751471001 04... " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sub0.columns = ['customer_id', 'prediction0']\n", + "sub0['prediction1'] = sub1['prediction']\n", + "sub0['prediction2'] = sub2['prediction']\n", + "del sub1, sub2\n", + "gc.collect()\n", + "sub0.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5d85f39e", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:44:51.043852Z", + "iopub.status.busy": "2022-03-30T00:44:51.042783Z", + "iopub.status.idle": "2022-03-30T00:48:22.448610Z", + "shell.execute_reply": "2022-03-30T00:48:22.449190Z", + "shell.execute_reply.started": "2022-03-20T03:31:44.407999Z" + }, + "papermill": { + "duration": 211.486514, + "end_time": "2022-03-30T00:48:22.449364", + "exception": false, + "start_time": "2022-03-30T00:44:50.962850", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction0prediction1prediction2prediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 07...0568601043 0568601006 0656719005 0745232001 07...0568601043 0924243001 0924243002 0918522001 07...0568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0739590027 0723529001 08...0826211002 0800436010 0739590027 0723529001 08...0924243001 0924243002 0918522001 0751471001 04...0826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 07...0794321007 0852643001 0852643003 0858883002 07...0794321007 0924243001 0924243002 0918522001 07...0794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0751471001 0706016001 06...0448509014 0573085028 0751471001 0706016001 06...0924243001 0924243002 0918522001 0751471001 04...0448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0896152002 0818320001 09...0730683050 0791587015 0896152002 0818320001 09...0924243001 0924243002 0918522001 0751471001 04...0730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction0 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction1 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction2 \\\n", + "0 0568601043 0924243001 0924243002 0918522001 07... \n", + "1 0924243001 0924243002 0918522001 0751471001 04... \n", + "2 0794321007 0924243001 0924243002 0918522001 07... \n", + "3 0924243001 0924243002 0918522001 0751471001 04... \n", + "4 0924243001 0924243002 0918522001 0751471001 04... \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def cust_blend(dt, W = [1,1,1]):\n", + " #Global ensemble weights\n", + " #W = [1.15,0.95,0.85]\n", + " \n", + " #Create a list of all model predictions\n", + " REC = []\n", + " REC.append(dt['prediction0'].split())\n", + " REC.append(dt['prediction1'].split())\n", + " REC.append(dt['prediction2'].split())\n", + " \n", + " #Create a dictionary of items recommended. \n", + " #Assign a weight according the order of appearance and multiply by global weights\n", + " res = {}\n", + " for M in range(len(REC)):\n", + " for n, v in enumerate(REC[M]):\n", + " if v in res:\n", + " res[v] += (W[M]/(n+1))\n", + " else:\n", + " res[v] = (W[M]/(n+1))\n", + " \n", + " # Sort dictionary by item weights\n", + " res = list(dict(sorted(res.items(), key=lambda item: -item[1])).keys())\n", + " \n", + " # Return the top 12 itens only\n", + " return ' '.join(res[:12])\n", + "\n", + "sub0['prediction'] = sub0.apply(cust_blend, W = [1.05,1.00,0.95], axis=1)\n", + "sub0.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "26885807", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:48:22.718278Z", + "iopub.status.busy": "2022-03-30T00:48:22.712658Z", + "iopub.status.idle": "2022-03-30T00:48:33.560171Z", + "shell.execute_reply": "2022-03-30T00:48:33.559554Z", + "shell.execute_reply.started": "2022-03-20T03:34:31.621742Z" + }, + "papermill": { + "duration": 11.032183, + "end_time": "2022-03-30T00:48:33.560355", + "exception": false, + "start_time": "2022-03-30T00:48:22.528172", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "del sub0['prediction0']\n", + "del sub0['prediction1']\n", + "del sub0['prediction2']\n", + "gc.collect()\n", + "sub0.to_csv(f'submission.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "31fc4d9b", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:48:33.809587Z", + "iopub.status.busy": "2022-03-30T00:48:33.808573Z", + "iopub.status.idle": "2022-03-30T00:48:33.812129Z", + "shell.execute_reply": "2022-03-30T00:48:33.812627Z", + "shell.execute_reply.started": "2022-03-20T03:34:42.724028Z" + }, + "papermill": { + "duration": 0.17302, + "end_time": "2022-03-30T00:48:33.812807", + "exception": false, + "start_time": "2022-03-30T00:48:33.639787", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del sub0\n", + "del df\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "ccca9df9", + "metadata": { + "papermill": { + "duration": 0.08023, + "end_time": "2022-03-30T00:48:33.974083", + "exception": false, + "start_time": "2022-03-30T00:48:33.893853", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 3. Create dataset and train model with Recbole\n", + "\n", + "For anyone need instruction document, please check this link: https://recbole.io/docs/user_guide/usage/use_modules.html" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f0ac4e7a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:48:34.139603Z", + "iopub.status.busy": "2022-03-30T00:48:34.138634Z", + "iopub.status.idle": "2022-03-30T00:48:37.030761Z", + "shell.execute_reply": "2022-03-30T00:48:37.029689Z", + "shell.execute_reply.started": "2022-03-20T03:34:42.811908Z" + }, + "papermill": { + "duration": 2.978139, + "end_time": "2022-03-30T00:48:37.030966", + "exception": false, + "start_time": "2022-03-30T00:48:34.052827", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import logging\n", + "from logging import getLogger\n", + "from recbole.config import Config\n", + "from recbole.data import create_dataset, data_preparation\n", + "from recbole.model.sequential_recommender import GRU4RecF\n", + "from recbole.trainer import Trainer\n", + "from recbole.utils import init_seed, init_logger" + ] + }, + { + "cell_type": "markdown", + "id": "cb2a6624", + "metadata": { + "papermill": { + "duration": 0.080352, + "end_time": "2022-03-30T00:48:37.191738", + "exception": false, + "start_time": "2022-03-30T00:48:37.111386", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "for limiting memory and time traning, we will filter for only using user who bought more than 40 items and item which is sold more than 40 times. If you want to train with more data, please change below config\n", + "* user_inter_num_interval\n", + "* item_inter_num_interval" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "ba356668", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:48:37.363904Z", + "iopub.status.busy": "2022-03-30T00:48:37.362514Z", + "iopub.status.idle": "2022-03-30T00:48:38.115093Z", + "shell.execute_reply": "2022-03-30T00:48:37.856983Z", + "shell.execute_reply.started": "2022-03-20T03:34:45.136018Z" + }, + "papermill": { + "duration": 0.842333, + "end_time": "2022-03-30T00:48:38.115271", + "exception": false, + "start_time": "2022-03-30T00:48:37.272938", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "General Hyper Parameters:\n", + "gpu_id = 0\n", + "use_gpu = True\n", + "seed = 2020\n", + "state = INFO\n", + "reproducibility = True\n", + "data_path = /kaggle/working/recbox_data\n", + "checkpoint_dir = saved\n", + "show_progress = True\n", + "save_dataset = False\n", + "dataset_save_path = None\n", + "save_dataloaders = False\n", + "dataloaders_save_path = None\n", + "log_wandb = False\n", + "\n", + "Training Hyper Parameters:\n", + "epochs = 30\n", + "train_batch_size = 2048\n", + "learner = adam\n", + "learning_rate = 0.001\n", + "neg_sampling = None\n", + "eval_step = 1\n", + "stopping_step = 10\n", + "clip_grad_norm = None\n", + "weight_decay = 0.0\n", + "loss_decimal_place = 4\n", + "\n", + "Evaluation Hyper Parameters:\n", + "eval_args = {'split': {'RS': [10, 0, 0]}, 'group_by': 'user', 'order': 'TO', 'mode': 'full'}\n", + "repeatable = True\n", + "metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n", + "topk = [10]\n", + "valid_metric = MRR@10\n", + "valid_metric_bigger = True\n", + "eval_batch_size = 4096\n", + "metric_decimal_place = 4\n", + "\n", + "Dataset Hyper Parameters:\n", + "field_separator = \t\n", + "seq_separator = \n", + "USER_ID_FIELD = user_id\n", + "ITEM_ID_FIELD = item_id\n", + "RATING_FIELD = rating\n", + "TIME_FIELD = timestamp\n", + "seq_len = None\n", + "LABEL_FIELD = label\n", + "threshold = None\n", + "NEG_PREFIX = neg_\n", + "load_col = {'inter': ['user_id', 'item_id', 'timestamp'], 'item': ['item_id', 'product_code', 'product_type_no', 'product_group_name', 'graphical_appearance_no', 'colour_group_code', 'perceived_colour_value_id', 'perceived_colour_master_id', 'department_no', 'index_code', 'index_group_no', 'section_no', 'garment_group_no']}\n", + "unload_col = None\n", + "unused_col = None\n", + "additional_feat_suffix = None\n", + "rm_dup_inter = None\n", + "val_interval = None\n", + "filter_inter_by_user_or_item = True\n", + "user_inter_num_interval = [40,inf)\n", + "item_inter_num_interval = [40,inf)\n", + "alias_of_user_id = None\n", + "alias_of_item_id = None\n", + "alias_of_entity_id = None\n", + "alias_of_relation_id = None\n", + "preload_weight = None\n", + "normalize_field = None\n", + "normalize_all = None\n", + "ITEM_LIST_LENGTH_FIELD = item_length\n", + "LIST_SUFFIX = _list\n", + "MAX_ITEM_LIST_LENGTH = 50\n", + "POSITION_FIELD = position_id\n", + "HEAD_ENTITY_ID_FIELD = head_id\n", + "TAIL_ENTITY_ID_FIELD = tail_id\n", + "RELATION_ID_FIELD = relation_id\n", + "ENTITY_ID_FIELD = entity_id\n", + "benchmark_filename = None\n", + "\n", + "Other Hyper Parameters: \n", + "wandb_project = recbole\n", + "require_pow = False\n", + "embedding_size = 64\n", + "hidden_size = 128\n", + "num_layers = 1\n", + "dropout_prob = 0.3\n", + "selected_features = ['product_code', 'product_type_no', 'product_group_name', 'graphical_appearance_no', 'colour_group_code', 'perceived_colour_value_id', 'perceived_colour_master_id', 'department_no', 'index_code', 'index_group_no', 'section_no', 'garment_group_no']\n", + "pooling_mode = sum\n", + "loss_type = CE\n", + "MODEL_TYPE = ModelType.SEQUENTIAL\n", + "MODEL_INPUT_TYPE = InputType.POINTWISE\n", + "eval_type = EvaluatorType.RANKING\n", + "device = cuda\n", + "train_neg_sample_args = {'strategy': 'none'}\n", + "eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "parameter_dict = {\n", + " 'data_path': '/kaggle/working',\n", + " 'USER_ID_FIELD': 'user_id',\n", + " 'ITEM_ID_FIELD': 'item_id',\n", + " 'TIME_FIELD': 'timestamp',\n", + " 'user_inter_num_interval': \"[40,inf)\",\n", + " 'item_inter_num_interval': \"[40,inf)\",\n", + " 'load_col': {'inter': ['user_id', 'item_id', 'timestamp'],\n", + " 'item': ['item_id', 'product_code', 'product_type_no', 'product_group_name', 'graphical_appearance_no',\n", + " 'colour_group_code', 'perceived_colour_value_id', 'perceived_colour_master_id',\n", + " 'department_no', 'index_code', 'index_group_no', 'section_no', 'garment_group_no']\n", + " },\n", + " 'selected_features': ['product_code', 'product_type_no', 'product_group_name', 'graphical_appearance_no',\n", + " 'colour_group_code', 'perceived_colour_value_id', 'perceived_colour_master_id',\n", + " 'department_no', 'index_code', 'index_group_no', 'section_no', 'garment_group_no'],\n", + " 'neg_sampling': None,\n", + " 'epochs': 30,\n", + " 'eval_args': {\n", + " 'split': {'RS': [10, 0, 0]},\n", + " 'group_by': 'user',\n", + " 'order': 'TO',\n", + " 'mode': 'full'}\n", + "}\n", + "\n", + "config = Config(model='GRU4RecF', dataset='recbox_data', config_dict=parameter_dict)\n", + "\n", + "# init random seed\n", + "init_seed(config['seed'], config['reproducibility'])\n", + "\n", + "# logger initialization\n", + "init_logger(config)\n", + "logger = getLogger()\n", + "# Create handlers\n", + "c_handler = logging.StreamHandler()\n", + "c_handler.setLevel(logging.INFO)\n", + "logger.addHandler(c_handler)\n", + "\n", + "# write config info into log\n", + "logger.info(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "476c6883", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:48:38.561771Z", + "iopub.status.busy": "2022-03-30T00:48:38.560009Z", + "iopub.status.idle": "2022-03-30T00:50:10.595639Z", + "shell.execute_reply": "2022-03-30T00:50:10.579753Z", + "shell.execute_reply.started": "2022-03-20T03:34:45.454951Z" + }, + "papermill": { + "duration": 92.260334, + "end_time": "2022-03-30T00:50:10.595802", + "exception": false, + "start_time": "2022-03-30T00:48:38.335468", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "recbox_data\n", + "The number of users: 15459\n", + "Average actions of users: 59.21956268598784\n", + "The number of items: 7330\n", + "Average actions of items: 124.9032610178742\n", + "The number of inters: 915416\n", + "The sparsity of the dataset: 99.19214553975321%\n", + "Remain Fields: ['user_id', 'item_id', 'timestamp', 'product_code', 'product_type_no', 'product_group_name', 'graphical_appearance_no', 'colour_group_code', 'perceived_colour_value_id', 'perceived_colour_master_id', 'department_no', 'index_code', 'index_group_no', 'section_no', 'garment_group_no']\n" + ] + } + ], + "source": [ + "dataset = create_dataset(config)\n", + "logger.info(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "e27072f2", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:50:11.033273Z", + "iopub.status.busy": "2022-03-30T00:50:11.032541Z", + "iopub.status.idle": "2022-03-30T00:50:30.653716Z", + "shell.execute_reply": "2022-03-30T00:50:30.648118Z", + "shell.execute_reply.started": "2022-03-20T03:36:01.448124Z" + }, + "papermill": { + "duration": 19.845623, + "end_time": "2022-03-30T00:50:30.653883", + "exception": false, + "start_time": "2022-03-30T00:50:10.808260", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Training]: train_batch_size = [2048] negative sampling: [None]\n", + "[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [10, 0, 0]}, 'group_by': 'user', 'order': 'TO', 'mode': 'full'}]\n" + ] + } + ], + "source": [ + "# dataset splitting\n", + "train_data, valid_data, test_data = data_preparation(config, dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "89dbcdd3", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T00:50:31.118124Z", + "iopub.status.busy": "2022-03-30T00:50:31.117440Z", + "iopub.status.idle": "2022-03-30T01:14:14.075339Z", + "shell.execute_reply": "2022-03-30T01:14:14.074553Z", + "shell.execute_reply.started": "2022-03-20T03:36:17.473947Z" + }, + "papermill": { + "duration": 1423.193275, + "end_time": "2022-03-30T01:14:14.075541", + "exception": false, + "start_time": "2022-03-30T00:50:30.882266", + "status": "completed" + }, + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GRU4RecF(\n", + " (item_embedding): Embedding(7330, 64, padding_idx=0)\n", + " (feature_embed_layer): FeatureSeqEmbLayer(\n", + " (token_embedding_table): ModuleDict(\n", + " (item): FMEmbedding(\n", + " (embedding): Embedding(3935, 64)\n", + " )\n", + " )\n", + " (float_embedding_table): ModuleDict(\n", + " (item): Embedding(1, 64)\n", + " )\n", + " (token_seq_embedding_table): ModuleDict(\n", + " (item): ModuleList(\n", + " (0): Embedding(16, 64)\n", + " )\n", + " )\n", + " )\n", + " (item_gru_layers): GRU(64, 128, bias=False, batch_first=True)\n", + " (feature_gru_layers): GRU(768, 128, bias=False, batch_first=True)\n", + " (dense_layer): Linear(in_features=256, out_features=64, bias=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " (loss_fct): CrossEntropyLoss()\n", + ")\n", + "Trainable parameters: 1156288\n", + "epoch 0 training [time: 47.65s, train loss: 3642.3637]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 1 training [time: 44.34s, train loss: 3390.6134]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 2 training [time: 44.25s, train loss: 3250.4472]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 3 training [time: 44.39s, train loss: 3163.3735]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 4 training [time: 44.24s, train loss: 3099.2533]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 5 training [time: 44.37s, train loss: 3044.1074]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 6 training [time: 44.19s, train loss: 2998.5542]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 7 training [time: 44.22s, train loss: 2962.4046]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 8 training [time: 44.27s, train loss: 2932.5592]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 9 training [time: 44.35s, train loss: 2907.5308]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 10 training [time: 44.34s, train loss: 2885.6282]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 11 training [time: 44.20s, train loss: 2867.5368]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 12 training [time: 44.19s, train loss: 2850.5957]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 13 training [time: 44.52s, train loss: 2836.0930]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 14 training [time: 44.21s, train loss: 2822.5238]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 15 training [time: 44.39s, train loss: 2811.0895]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 16 training [time: 44.43s, train loss: 2799.8698]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 17 training [time: 44.28s, train loss: 2790.4907]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 18 training [time: 44.26s, train loss: 2781.8785]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 19 training [time: 44.24s, train loss: 2774.2283]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 20 training [time: 44.68s, train loss: 2766.8078]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 21 training [time: 44.25s, train loss: 2760.0341]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 22 training [time: 44.42s, train loss: 2753.8305]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 23 training [time: 44.29s, train loss: 2748.0924]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 24 training [time: 44.33s, train loss: 2742.5462]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 25 training [time: 44.21s, train loss: 2737.9435]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 26 training [time: 44.25s, train loss: 2732.9869]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 27 training [time: 44.63s, train loss: 2728.7535]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 28 training [time: 44.28s, train loss: 2724.6929]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n", + "epoch 29 training [time: 44.38s, train loss: 2720.6401]\n", + "Saving current: saved/GRU4RecF-Mar-30-2022_00-50-40.pth\n" + ] + } + ], + "source": [ + "# model loading and initialization\n", + "model = GRU4RecF(config, train_data.dataset).to(config['device'])\n", + "logger.info(model)\n", + "\n", + "# trainer loading and initialization\n", + "trainer = Trainer(config, model)\n", + "\n", + "# model training\n", + "best_valid_score, best_valid_result = trainer.fit(train_data)" + ] + }, + { + "cell_type": "markdown", + "id": "f40b6bc6", + "metadata": { + "papermill": { + "duration": 0.347651, + "end_time": "2022-03-30T01:14:14.780591", + "exception": false, + "start_time": "2022-03-30T01:14:14.432940", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 4. Create recommendation result from trained model\n", + "\n", + "I note document here for any one want to customize it: https://recbole.io/docs/user_guide/usage/case_study.html" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "9cbfb308", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:14:15.484564Z", + "iopub.status.busy": "2022-03-30T01:14:15.483621Z", + "iopub.status.idle": "2022-03-30T01:14:15.492929Z", + "shell.execute_reply": "2022-03-30T01:14:15.492376Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.568423Z" + }, + "papermill": { + "duration": 0.361155, + "end_time": "2022-03-30T01:14:15.493104", + "exception": false, + "start_time": "2022-03-30T01:14:15.131949", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from recbole.utils.case_study import full_sort_topk\n", + "external_user_ids = dataset.id2token(\n", + " dataset.uid_field, list(range(dataset.user_num)))[1:]#fist element in array is 'PAD'(default of Recbole) ->remove it " + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d7e54de4", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:14:16.564600Z", + "iopub.status.busy": "2022-03-30T01:14:16.563632Z", + "iopub.status.idle": "2022-03-30T01:14:16.567396Z", + "shell.execute_reply": "2022-03-30T01:14:16.567950Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.570281Z" + }, + "papermill": { + "duration": 0.544809, + "end_time": "2022-03-30T01:14:16.568142", + "exception": false, + "start_time": "2022-03-30T01:14:16.023333", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "from recbole.data.interaction import Interaction\n", + "\n", + "def add_last_item(old_interaction, last_item_id, max_len=50):\n", + " new_seq_items = old_interaction['item_id_list'][-1]\n", + " if old_interaction['item_length'][-1].item() < max_len:\n", + " new_seq_items[old_interaction['item_length'][-1].item()] = last_item_id\n", + " else:\n", + " new_seq_items = torch.roll(new_seq_items, -1)\n", + " new_seq_items[-1] = last_item_id\n", + " return new_seq_items.view(1, len(new_seq_items))\n", + "\n", + "def predict_for_all_item(external_user_id, dataset, model):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " uid_series = dataset.token2id(dataset.uid_field, [external_user_id])\n", + " index = np.isin(dataset.inter_feat[dataset.uid_field].numpy(), uid_series)\n", + " input_interaction = dataset[index]\n", + " test = {\n", + " 'item_id_list': add_last_item(input_interaction, \n", + " input_interaction['item_id'][-1].item(), model.max_seq_length),\n", + " 'item_length': torch.tensor(\n", + " [input_interaction['item_length'][-1].item() + 1\n", + " if input_interaction['item_length'][-1].item() < model.max_seq_length else model.max_seq_length])\n", + " }\n", + " new_inter = Interaction(test)\n", + " new_inter = new_inter.to(config['device'])\n", + " new_scores = model.full_sort_predict(new_inter)\n", + " new_scores = new_scores.view(-1, test_data.dataset.item_num)\n", + " new_scores[:, 0] = -np.inf # set scores of [pad] to -inf\n", + " return torch.topk(new_scores, 12)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "e6c29c5a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:14:17.240185Z", + "iopub.status.busy": "2022-03-30T01:14:17.239246Z", + "iopub.status.idle": "2022-03-30T01:14:17.432694Z", + "shell.execute_reply": "2022-03-30T01:14:17.433385Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.572226Z" + }, + "papermill": { + "duration": 0.523213, + "end_time": "2022-03-30T01:14:17.433563", + "exception": false, + "start_time": "2022-03-30T01:14:16.910350", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.return_types.topk(\n", + "values=tensor([[8.2494, 7.3775, 7.2640, 6.9034, 6.8279, 6.5702, 6.4665, 6.2845, 6.0385,\n", + " 6.0317, 5.9852, 5.9483]], device='cuda:0'),\n", + "indices=tensor([[4559, 1608, 5835, 4529, 1187, 5412, 371, 2589, 4579, 638, 2019, 2415]],\n", + " device='cuda:0'))" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict_for_all_item('0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef', \n", + " dataset, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "47107721", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:14:18.083485Z", + "iopub.status.busy": "2022-03-30T01:14:18.082537Z", + "iopub.status.idle": "2022-03-30T01:20:10.117399Z", + "shell.execute_reply": "2022-03-30T01:20:10.116413Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.574183Z" + }, + "papermill": { + "duration": 352.362944, + "end_time": "2022-03-30T01:20:10.117614", + "exception": false, + "start_time": "2022-03-30T01:14:17.754670", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15458\n" + ] + } + ], + "source": [ + "topk_items = []\n", + "for external_user_id in external_user_ids:\n", + " _, topk_iid_list = predict_for_all_item(external_user_id, dataset, model)\n", + " last_topk_iid_list = topk_iid_list[-1]\n", + " external_item_list = dataset.id2token(dataset.iid_field, last_topk_iid_list.cpu()).tolist()\n", + " topk_items.append(external_item_list)\n", + "print(len(topk_items))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "4c446bd6", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:10.838441Z", + "iopub.status.busy": "2022-03-30T01:20:10.837277Z", + "iopub.status.idle": "2022-03-30T01:20:10.854113Z", + "shell.execute_reply": "2022-03-30T01:20:10.853536Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.575588Z" + }, + "papermill": { + "duration": 0.357176, + "end_time": "2022-03-30T01:20:10.854291", + "exception": false, + "start_time": "2022-03-30T01:20:10.497115", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction
000d7ebd46f6a6d53630d41386b6ef6a505cdc4c80011ff...0918522001 0910601003 0673677022 0910601002 08...
10109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...0901955001 0833530002 0913030001 0861477001 08...
2013f00f9e218549246a3aa82b3f3a0c22a693bc25fa735...0839402002 0865172003 0839402001 0770336001 08...
301bada2a453b09c70ea57bdda5a9ef0fb04062718d3a3d...0914441004 0724906006 0868874006 0867966009 07...
401dd96059a11759518f10969d0a528f03c8501dc4c628b...0891663002 0850244002 0817353008 0891663001 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00d7ebd46f6a6d53630d41386b6ef6a505cdc4c80011ff... \n", + "1 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "2 013f00f9e218549246a3aa82b3f3a0c22a693bc25fa735... \n", + "3 01bada2a453b09c70ea57bdda5a9ef0fb04062718d3a3d... \n", + "4 01dd96059a11759518f10969d0a528f03c8501dc4c628b... \n", + "\n", + " prediction \n", + "0 0918522001 0910601003 0673677022 0910601002 08... \n", + "1 0901955001 0833530002 0913030001 0861477001 08... \n", + "2 0839402002 0865172003 0839402001 0770336001 08... \n", + "3 0914441004 0724906006 0868874006 0867966009 07... \n", + "4 0891663002 0850244002 0817353008 0891663001 08... " + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "external_item_str = [' '.join(x) for x in topk_items]\n", + "result = pd.DataFrame(external_user_ids, columns=['customer_id'])\n", + "result['prediction'] = external_item_str\n", + "result.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "76fed207", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:11.742604Z", + "iopub.status.busy": "2022-03-30T01:20:11.549113Z", + "iopub.status.idle": "2022-03-30T01:20:11.745644Z", + "shell.execute_reply": "2022-03-30T01:20:11.746258Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.577985Z" + }, + "papermill": { + "duration": 0.563057, + "end_time": "2022-03-30T01:20:11.746435", + "exception": false, + "start_time": "2022-03-30T01:20:11.183378", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del external_item_str\n", + "del topk_items\n", + "del external_user_ids\n", + "del train_data\n", + "del valid_data\n", + "del test_data\n", + "del model\n", + "del Trainer\n", + "del logger\n", + "del dataset\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "a9bdc27a", + "metadata": { + "papermill": { + "duration": 0.32871, + "end_time": "2022-03-30T01:20:12.399473", + "exception": false, + "start_time": "2022-03-30T01:20:12.070763", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 5. Combine result from most bought items and GRU model" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "ac55b63e", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:13.063277Z", + "iopub.status.busy": "2022-03-30T01:20:13.062460Z", + "iopub.status.idle": "2022-03-30T01:20:16.823768Z", + "shell.execute_reply": "2022-03-30T01:20:16.823141Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.579843Z" + }, + "papermill": { + "duration": 4.096581, + "end_time": "2022-03-30T01:20:16.823920", + "exception": false, + "start_time": "2022-03-30T01:20:12.727339", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(1371980, 2)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = pd.read_csv('submission.csv')\n", + "submit_df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "c88aad11", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:17.491328Z", + "iopub.status.busy": "2022-03-30T01:20:17.489060Z", + "iopub.status.idle": "2022-03-30T01:20:17.495042Z", + "shell.execute_reply": "2022-03-30T01:20:17.495635Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.581657Z" + }, + "papermill": { + "duration": 0.346787, + "end_time": "2022-03-30T01:20:17.495829", + "exception": false, + "start_time": "2022-03-30T01:20:17.149042", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "88c4db15", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:18.936665Z", + "iopub.status.busy": "2022-03-30T01:20:18.935582Z", + "iopub.status.idle": "2022-03-30T01:20:19.413534Z", + "shell.execute_reply": "2022-03-30T01:20:19.414149Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.583483Z" + }, + "papermill": { + "duration": 1.582283, + "end_time": "2022-03-30T01:20:19.414370", + "exception": false, + "start_time": "2022-03-30T01:20:17.832087", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction_xprediction_y
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...NaN
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...NaN
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...NaN
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...NaN
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...NaN
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction_x prediction_y \n", + "0 0568601043 0568601006 0656719005 0745232001 09... NaN \n", + "1 0826211002 0800436010 0924243001 0739590027 07... NaN \n", + "2 0794321007 0852643001 0852643003 0858883002 09... NaN \n", + "3 0448509014 0573085028 0924243001 0751471001 07... NaN \n", + "4 0730683050 0791587015 0924243001 0896152002 08... NaN " + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = pd.merge(submit_df, result, on='customer_id', how='outer')\n", + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "b9381c71", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:20.617905Z", + "iopub.status.busy": "2022-03-30T01:20:20.616834Z", + "iopub.status.idle": "2022-03-30T01:20:49.699158Z", + "shell.execute_reply": "2022-03-30T01:20:49.698558Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.58531Z" + }, + "papermill": { + "duration": 29.957585, + "end_time": "2022-03-30T01:20:49.699340", + "exception": false, + "start_time": "2022-03-30T01:20:19.741755", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction_xprediction_yprediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...-10568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...-10826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...-10794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...-10448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...-10730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction_x prediction_y \\\n", + "0 0568601043 0568601006 0656719005 0745232001 09... -1 \n", + "1 0826211002 0800436010 0924243001 0739590027 07... -1 \n", + "2 0794321007 0852643001 0852643003 0858883002 09... -1 \n", + "3 0448509014 0573085028 0924243001 0751471001 07... -1 \n", + "4 0730683050 0791587015 0924243001 0896152002 08... -1 \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = submit_df.fillna(-1)\n", + "submit_df['prediction'] = submit_df.apply(\n", + " lambda x: x['prediction_y'] if x['prediction_y'] != -1 else x['prediction_x'], axis=1)\n", + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "0b29f88a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:50.360135Z", + "iopub.status.busy": "2022-03-30T01:20:50.358717Z", + "iopub.status.idle": "2022-03-30T01:20:50.708834Z", + "shell.execute_reply": "2022-03-30T01:20:50.709427Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.587134Z" + }, + "papermill": { + "duration": 0.684384, + "end_time": "2022-03-30T01:20:50.709672", + "exception": false, + "start_time": "2022-03-30T01:20:50.025288", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = submit_df.drop(columns=['prediction_y', 'prediction_x'])\n", + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "0809a3da", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-30T01:20:51.371957Z", + "iopub.status.busy": "2022-03-30T01:20:51.370974Z", + "iopub.status.idle": "2022-03-30T01:21:02.714475Z", + "shell.execute_reply": "2022-03-30T01:21:02.716457Z", + "shell.execute_reply.started": "2022-03-20T04:00:34.588995Z" + }, + "papermill": { + "duration": 11.685594, + "end_time": "2022-03-30T01:21:02.716804", + "exception": false, + "start_time": "2022-03-30T01:20:51.031210", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "submit_df.to_csv('submission.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96f1ddb7", + "metadata": { + "papermill": { + "duration": 0.348025, + "end_time": "2022-03-30T01:21:03.656677", + "exception": false, + "start_time": "2022-03-30T01:21:03.308652", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + }, + "papermill": { + "default_parameters": {}, + "duration": 2378.932098, + "end_time": "2022-03-30T01:21:06.903964", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2022-03-30T00:41:27.971866", + "version": "2.3.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/run_example/recbole-using-all-items-for-prediction.ipynb b/run_example/recbole-using-all-items-for-prediction.ipynb new file mode 100644 index 000000000..db9ce3d5f --- /dev/null +++ b/run_example/recbole-using-all-items-for-prediction.ipynb @@ -0,0 +1,1969 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f0100cdc", + "metadata": { + "papermill": { + "duration": 0.051507, + "end_time": "2022-03-20T03:21:48.747089", + "exception": false, + "start_time": "2022-03-20T03:21:48.695582", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "**Edit**: \n", + "I have create new notebooks for applying our customize function for using all items as input for recommendation:\n", + "* Using only interactions: https://www.kaggle.com/astrung/sequential-model-fixed-missing-last-item\n", + "* Using interactions with item features: https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item\n", + "\n", + "- - -\n", + "\n", + "In my previous [notebook](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial) about sequential model with Recbole, someone asked me about the mechanism of test data when using `full_sort_topk` for prediction submitted recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707) and this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707), and they have a doubt about whether we are using all items for getting final recommendation. Most of people has 2 questions about using `full_sort_topk` with test data:\n", + "1. Do items in test data are used as input features for getting recommendation ?\n", + "2. If test data is necessary for getting recommendation in Recbole API, how can we get recommendation without splitting into train/test data?\n", + "\n", + "In this notebook i will answer all questions:\n", + "1. Yes. In sequential models, items in test data is used as input features, but not last items. As a example, if user X have 3 items in test data(A, B, C) and 5 items in train data(a,b,c,d,e), test data will generate 3 sample rows for evaluating performance on user X:\n", + "* Row 1: Input features: `a,b,c,d,e,0,0`. Output features: `A`. `0` is a pad item\n", + "* Row 2: Input features: `a,b,c,d,e,A,0`. Output features: `B`.\n", + "* Row 3: Input features: `a,b,c,d,e,A,B`. Output features: `C`.\n", + "\n", + "In my previous notebook, i use last row result as recommendation, **so we still using nearly all of items as input for recommendation, except last item(item C)**. Our recommendation in previous notebooks may be not perfect, but it is simple as a tutorial for anyone want to start.\n", + "\n", + "**Note: This mechanism is only for sequential model in recbole. For other types of model, it isn't correct - it won't use items in test data for getting recommendation. If you have requests for explaining for other model, please upvote and comment. I will explain it in other notebook**\n", + "\n", + "In first session of this notebook, i will dig into test data to prove this conclusion.\n", + "\n", + "2. Yes, we can get recommendation by using all of items as input features, without splitting train/test. In order to do this, you need to modify recbole code:\n", + "* Fist, you copy last row in dataset(input features have all items, except last one), then add last item into input features.\n", + "* Then you predict directly from model api, without using [full_sort_score or full_sort_topk](https://recbole.io/docs/user_guide/usage/case_study.html)\n", + "\n", + "In second session of this notebook, i will show you how to do that.\n", + "\n", + "Ok, let start" + ] + }, + { + "cell_type": "markdown", + "id": "bbb84f80", + "metadata": { + "papermill": { + "duration": 0.092534, + "end_time": "2022-03-20T03:21:48.923117", + "exception": false, + "start_time": "2022-03-20T03:21:48.830583", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# I. How test items are used in test data.\n", + "\n", + "For each item in test data, it will be generated as a sample row. As a example, if user X have 3 items in test data(A, B, C) and 5 items in train data(a,b,c,d,e), test data will generate 3 sample rows for evaluating performance on user X:\n", + "* Row 1: Input features: `a,b,c,d,e,0,0`. Output label: `A`. `0` is a pad item\n", + "* Row 2: Input features: `a,b,c,d,e,A,0`. Output label: `B`.\n", + "* Row 3: Input features: `a,b,c,d,e,A,B`. Output label: `C`.\n", + "\n", + "For proving it, we will create a dataset, then extract input features and label in test data.\n", + "\n", + "### 1. Let create test data in recbole" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "796894e4", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2022-03-20T03:21:49.077123Z", + "iopub.status.busy": "2022-03-20T03:21:49.076011Z", + "iopub.status.idle": "2022-03-20T03:22:10.110393Z", + "shell.execute_reply": "2022-03-20T03:22:10.109242Z", + "shell.execute_reply.started": "2022-03-19T07:15:02.071127Z" + }, + "papermill": { + "duration": 21.115617, + "end_time": "2022-03-20T03:22:10.110594", + "exception": false, + "start_time": "2022-03-20T03:21:48.994977", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting recbole\r\n", + " Downloading recbole-1.0.1-py3-none-any.whl (2.0 MB)\r\n", + " |████████████████████████████████| 2.0 MB 537 kB/s \r\n", + "\u001b[?25hRequirement already satisfied: pandas>=1.0.5 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.3.5)\r\n", + "Collecting colorlog==4.7.2\r\n", + " Downloading colorlog-4.7.2-py2.py3-none-any.whl (10 kB)\r\n", + "Collecting scipy==1.6.0\r\n", + " Downloading scipy-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (27.4 MB)\r\n", + " |████████████████████████████████| 27.4 MB 99 kB/s \r\n", + "\u001b[?25hRequirement already satisfied: torch>=1.7.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.9.1)\r\n", + "Requirement already satisfied: colorama==0.4.4 in /opt/conda/lib/python3.7/site-packages (from recbole) (0.4.4)\r\n", + "Requirement already satisfied: numpy>=1.17.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.20.3)\r\n", + "Requirement already satisfied: pyyaml>=5.1.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (6.0)\r\n", + "Requirement already satisfied: tensorboard>=2.5.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (2.6.0)\r\n", + "Requirement already satisfied: scikit-learn>=0.23.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.0.1)\r\n", + "Requirement already satisfied: tqdm>=4.48.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (4.62.3)\r\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas>=1.0.5->recbole) (2.8.2)\r\n", + "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas>=1.0.5->recbole) (2021.3)\r\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn>=0.23.2->recbole) (3.0.0)\r\n", + "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn>=0.23.2->recbole) (1.1.0)\r\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.8.0)\r\n", + "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (59.5.0)\r\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.4.6)\r\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (2.26.0)\r\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (2.0.2)\r\n", + "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.15.0)\r\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.6.1)\r\n", + "Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.37.0)\r\n", + "Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (3.19.4)\r\n", + "Requirement already satisfied: google-auth<2,>=1.6.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.35.0)\r\n", + "Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.43.0)\r\n", + "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (3.3.6)\r\n", + "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.7.0->recbole) (4.1.1)\r\n", + "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from absl-py>=0.4->tensorboard>=2.5.0->recbole) (1.16.0)\r\n", + "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (4.2.4)\r\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (0.2.7)\r\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (4.8)\r\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.5.0->recbole) (1.3.0)\r\n", + "Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.7/site-packages (from markdown>=2.6.8->tensorboard>=2.5.0->recbole) (4.11.3)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (1.26.7)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (2021.10.8)\r\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (2.0.9)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (3.1)\r\n", + "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.5.0->recbole) (3.6.0)\r\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (0.4.8)\r\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.5.0->recbole) (3.1.1)\r\n", + "Installing collected packages: scipy, colorlog, recbole\r\n", + " Attempting uninstall: scipy\r\n", + " Found existing installation: scipy 1.7.3\r\n", + " Uninstalling scipy-1.7.3:\r\n", + " Successfully uninstalled scipy-1.7.3\r\n", + " Attempting uninstall: colorlog\r\n", + " Found existing installation: colorlog 6.6.0\r\n", + " Uninstalling colorlog-6.6.0:\r\n", + " Successfully uninstalled colorlog-6.6.0\r\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n", + "pymc3 3.11.5 requires scipy<1.8.0,>=1.7.3, but you have scipy 1.6.0 which is incompatible.\r\n", + "pdpbox 0.2.1 requires matplotlib==3.1.1, but you have matplotlib 3.5.1 which is incompatible.\r\n", + "featuretools 1.6.0 requires numpy>=1.21.0, but you have numpy 1.20.3 which is incompatible.\r\n", + "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001b[0m\r\n", + "Successfully installed colorlog-4.7.2 recbole-1.0.1 scipy-1.6.0\r\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\r\n" + ] + } + ], + "source": [ + "!pip install recbole" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "73239bab", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:22:10.257515Z", + "iopub.status.busy": "2022-03-20T03:22:10.256377Z", + "iopub.status.idle": "2022-03-20T03:23:16.176085Z", + "shell.execute_reply": "2022-03-20T03:23:16.176596Z", + "shell.execute_reply.started": "2022-03-19T07:15:17.429911Z" + }, + "papermill": { + "duration": 65.997767, + "end_time": "2022-03-20T03:23:16.176777", + "exception": false, + "start_time": "2022-03-20T03:22:10.179010", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_id
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.0508312
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.0304922
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.0152372
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.0169322
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.0169322
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id article_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0663713001 \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0541518023 \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0505221004 \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687003 \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687004 \n", + "\n", + " price sales_channel_id \n", + "0 0.050831 2 \n", + "1 0.030492 2 \n", + "2 0.015237 2 \n", + "3 0.016932 2 \n", + "4 0.016932 2 " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "df = pd.read_csv(r\"/kaggle/input/h-and-m-personalized-fashion-recommendations/transactions_train.csv\", \n", + " dtype={'article_id': 'str'})\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "26fe44ff", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:23:18.242810Z", + "iopub.status.busy": "2022-03-20T03:23:18.241761Z", + "iopub.status.idle": "2022-03-20T03:23:23.253894Z", + "shell.execute_reply": "2022-03-20T03:23:23.254595Z", + "shell.execute_reply.started": "2022-03-19T07:16:16.934597Z" + }, + "papermill": { + "duration": 7.012762, + "end_time": "2022-03-20T03:23:23.254783", + "exception": false, + "start_time": "2022-03-20T03:23:16.242021", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_idtimestamp
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.05083121537401600
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.03049221537401600
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.01523721537401600
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.01693221537401600
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.01693221537401600
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id article_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0663713001 \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0541518023 \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0505221004 \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687003 \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687004 \n", + "\n", + " price sales_channel_id timestamp \n", + "0 0.050831 2 1537401600 \n", + "1 0.030492 2 1537401600 \n", + "2 0.015237 2 1537401600 \n", + "3 0.016932 2 1537401600 \n", + "4 0.016932 2 1537401600 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "df['t_dat'] = pd.to_datetime(df['t_dat'], format=\"%Y-%m-%d\")\n", + "df['timestamp'] = df.t_dat.values.astype(np.int64) // 10 ** 9\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b4b9ffc7", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:23:23.406061Z", + "iopub.status.busy": "2022-03-20T03:23:23.404925Z", + "iopub.status.idle": "2022-03-20T03:23:25.058648Z", + "shell.execute_reply": "2022-03-20T03:23:25.059139Z", + "shell.execute_reply.started": "2022-03-19T07:16:23.242276Z" + }, + "papermill": { + "duration": 1.734735, + "end_time": "2022-03-20T03:23:25.059306", + "exception": false, + "start_time": "2022-03-20T03:23:23.324571", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_id:tokenitem_id:tokentimestamp:float
23934157000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...07278080011585699200
23934158000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...07278080071585699200
23934159000563485cbb7850b0a93c6606f89c5b961c6647d1bd48...05675320151585699200
23934160000563485cbb7850b0a93c6606f89c5b961c6647d1bd48...07061040091585699200
2393416100083cda041544b2fbb0e0d2905ad17da7cf1007526fb4...07835040041585699200
............
31788319fff2282977442e327b45d8c89afde25617d00124d0f999...09295110011600732800
31788320fff2282977442e327b45d8c89afde25617d00124d0f999...08913220041600732800
31788321fff380805474b287b05cb2a7507b9a013482f7dd0bce0e...09183250011600732800
31788322fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5...08334590021600732800
31788323fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20...08985730031600732800
\n", + "

7854167 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id:token item_id:token \\\n", + "23934157 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0727808001 \n", + "23934158 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0727808007 \n", + "23934159 000563485cbb7850b0a93c6606f89c5b961c6647d1bd48... 0567532015 \n", + "23934160 000563485cbb7850b0a93c6606f89c5b961c6647d1bd48... 0706104009 \n", + "23934161 00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4... 0783504004 \n", + "... ... ... \n", + "31788319 fff2282977442e327b45d8c89afde25617d00124d0f999... 0929511001 \n", + "31788320 fff2282977442e327b45d8c89afde25617d00124d0f999... 0891322004 \n", + "31788321 fff380805474b287b05cb2a7507b9a013482f7dd0bce0e... 0918325001 \n", + "31788322 fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5... 0833459002 \n", + "31788323 fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20... 0898573003 \n", + "\n", + " timestamp:float \n", + "23934157 1585699200 \n", + "23934158 1585699200 \n", + "23934159 1585699200 \n", + "23934160 1585699200 \n", + "23934161 1585699200 \n", + "... ... \n", + "31788319 1600732800 \n", + "31788320 1600732800 \n", + "31788321 1600732800 \n", + "31788322 1600732800 \n", + "31788323 1600732800 \n", + "\n", + "[7854167 rows x 3 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp = df[df['timestamp'] > 1585620000][['customer_id', 'article_id', 'timestamp']].rename(\n", + " columns={'customer_id': 'user_id:token', 'article_id': 'item_id:token', 'timestamp': 'timestamp:float'})\n", + "temp" + ] + }, + { + "cell_type": "markdown", + "id": "d0f27239", + "metadata": { + "papermill": { + "duration": 0.065277, + "end_time": "2022-03-20T03:23:25.191185", + "exception": false, + "start_time": "2022-03-20T03:23:25.125908", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Create data file in recbole format" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fb54432a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:23:25.330985Z", + "iopub.status.busy": "2022-03-20T03:23:25.330083Z", + "iopub.status.idle": "2022-03-20T03:24:02.483165Z", + "shell.execute_reply": "2022-03-20T03:24:02.482632Z", + "shell.execute_reply.started": "2022-03-19T07:16:24.818521Z" + }, + "papermill": { + "duration": 37.226498, + "end_time": "2022-03-20T03:24:02.483347", + "exception": false, + "start_time": "2022-03-20T03:23:25.256849", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!mkdir /kaggle/working/recbox_data\n", + "temp.to_csv('/kaggle/working/recbox_data/recbox_data.inter', index=False, sep='\\t')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cf1cf5ea", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:24:02.732875Z", + "iopub.status.busy": "2022-03-20T03:24:02.731793Z", + "iopub.status.idle": "2022-03-20T03:24:02.737943Z", + "shell.execute_reply": "2022-03-20T03:24:02.738447Z", + "shell.execute_reply.started": "2022-03-19T07:17:00.442243Z" + }, + "papermill": { + "duration": 0.188845, + "end_time": "2022-03-20T03:24:02.738620", + "exception": false, + "start_time": "2022-03-20T03:24:02.549775", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "del temp\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2f9f96fd", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:24:02.880709Z", + "iopub.status.busy": "2022-03-20T03:24:02.879572Z", + "iopub.status.idle": "2022-03-20T03:24:05.595930Z", + "shell.execute_reply": "2022-03-20T03:24:05.596790Z", + "shell.execute_reply.started": "2022-03-19T07:17:00.552996Z" + }, + "papermill": { + "duration": 2.793206, + "end_time": "2022-03-20T03:24:05.597057", + "exception": false, + "start_time": "2022-03-20T03:24:02.803851", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import logging\n", + "from logging import getLogger\n", + "from recbole.config import Config\n", + "from recbole.data import create_dataset, data_preparation\n", + "from recbole.model.sequential_recommender import GRU4Rec\n", + "from recbole.trainer import Trainer\n", + "from recbole.utils import init_seed, init_logger" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "03cc093c", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:24:05.843816Z", + "iopub.status.busy": "2022-03-20T03:24:05.842712Z", + "iopub.status.idle": "2022-03-20T03:24:06.028813Z", + "shell.execute_reply": "2022-03-20T03:24:06.030099Z", + "shell.execute_reply.started": "2022-03-19T07:17:02.983278Z" + }, + "papermill": { + "duration": 0.322194, + "end_time": "2022-03-20T03:24:06.030375", + "exception": false, + "start_time": "2022-03-20T03:24:05.708181", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "parameter_dict = {\n", + " 'data_path': '/kaggle/working',\n", + " 'USER_ID_FIELD': 'user_id',\n", + " 'ITEM_ID_FIELD': 'item_id',\n", + " 'TIME_FIELD': 'timestamp',\n", + " 'user_inter_num_interval': \"[40,inf)\",\n", + " 'item_inter_num_interval': \"[40,inf)\",\n", + " 'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},\n", + " 'neg_sampling': None,\n", + " 'epochs': 2,\n", + " 'eval_args': {\n", + " 'split': {'RS': [9, 0, 1]},\n", + " 'group_by': 'user',\n", + " 'order': 'TO',\n", + " 'mode': 'full'}\n", + "}\n", + "config = Config(model='GRU4Rec', dataset='recbox_data', config_dict=parameter_dict)\n", + "\n", + "# init random seed\n", + "init_seed(config['seed'], config['reproducibility'])\n", + "\n", + "# logger initialization\n", + "init_logger(config)\n", + "logger = getLogger()\n", + "# Create handlers\n", + "c_handler = logging.StreamHandler()\n", + "c_handler.setLevel(logging.INFO)\n", + "logger.addHandler(c_handler)\n", + "\n", + "# write config info into log\n", + "# logger.info(config)" + ] + }, + { + "cell_type": "markdown", + "id": "c1127367", + "metadata": { + "papermill": { + "duration": 0.072153, + "end_time": "2022-03-20T03:24:06.208954", + "exception": false, + "start_time": "2022-03-20T03:24:06.136801", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Now let start spliting train data and test data in recbole" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e7b1cfd3", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:24:06.355371Z", + "iopub.status.busy": "2022-03-20T03:24:06.354247Z", + "iopub.status.idle": "2022-03-20T03:25:34.154014Z", + "shell.execute_reply": "2022-03-20T03:25:34.135949Z", + "shell.execute_reply.started": "2022-03-19T07:17:03.140307Z" + }, + "papermill": { + "duration": 87.875574, + "end_time": "2022-03-20T03:25:34.154168", + "exception": false, + "start_time": "2022-03-20T03:24:06.278594", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "recbox_data\n", + "The number of users: 15459\n", + "Average actions of users: 59.21956268598784\n", + "The number of items: 7330\n", + "Average actions of items: 124.9032610178742\n", + "The number of inters: 915416\n", + "The sparsity of the dataset: 99.19214553975321%\n", + "Remain Fields: ['user_id', 'item_id', 'timestamp']\n" + ] + } + ], + "source": [ + "dataset = create_dataset(config)\n", + "logger.info(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a317f077", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:25:34.315274Z", + "iopub.status.busy": "2022-03-20T03:25:34.314599Z", + "iopub.status.idle": "2022-03-20T03:25:56.109208Z", + "shell.execute_reply": "2022-03-20T03:25:56.101020Z", + "shell.execute_reply.started": "2022-03-19T07:18:19.704645Z" + }, + "papermill": { + "duration": 21.880677, + "end_time": "2022-03-20T03:25:56.109483", + "exception": false, + "start_time": "2022-03-20T03:25:34.228806", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Training]: train_batch_size = [2048] negative sampling: [None]\n", + "[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'group_by': 'user', 'order': 'TO', 'mode': 'full'}]\n" + ] + } + ], + "source": [ + "# dataset splitting\n", + "train_data, valid_data, test_data = data_preparation(config, dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "3a74f096", + "metadata": { + "papermill": { + "duration": 0.082704, + "end_time": "2022-03-20T03:25:56.277447", + "exception": false, + "start_time": "2022-03-20T03:25:56.194743", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### 2. Let extract sample rows from test data.\n", + "\n", + "We will check items of user `0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef`. \n", + "\n", + "We except that last items of this user will be used as label in test data\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "410c4c02", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:01.061517Z", + "iopub.status.busy": "2022-03-20T03:26:01.060299Z", + "iopub.status.idle": "2022-03-20T03:26:05.674214Z", + "shell.execute_reply": "2022-03-20T03:26:05.673707Z", + "shell.execute_reply.started": "2022-03-19T07:18:39.031527Z" + }, + "papermill": { + "duration": 9.312405, + "end_time": "2022-03-20T03:26:05.674364", + "exception": false, + "start_time": "2022-03-20T03:25:56.361959", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_idtimestamp
293152292020-07-220109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...08427270010.02710221595376000
293152302020-07-220109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...08502410050.02032221595376000
293152312020-07-220109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...08715170130.02032221595376000
304210032020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...09030620010.04088121597622400
304210042020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...08614780030.02452521597622400
304210052020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...08577780110.03269521597622400
304210062020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...06982860040.02452521597622400
304210072020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...06982860040.02452521597622400
304210082020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...08614780020.01633921597622400
304210092020-08-170109ad0b5a76924a1b58be677409bb601cc8bead9a87b8...09019550010.02452521597622400
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id \\\n", + "29315229 2020-07-22 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "29315230 2020-07-22 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "29315231 2020-07-22 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421003 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421004 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421005 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421006 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421007 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421008 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "30421009 2020-08-17 0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8... \n", + "\n", + " article_id price sales_channel_id timestamp \n", + "29315229 0842727001 0.027102 2 1595376000 \n", + "29315230 0850241005 0.020322 2 1595376000 \n", + "29315231 0871517013 0.020322 2 1595376000 \n", + "30421003 0903062001 0.040881 2 1597622400 \n", + "30421004 0861478003 0.024525 2 1597622400 \n", + "30421005 0857778011 0.032695 2 1597622400 \n", + "30421006 0698286004 0.024525 2 1597622400 \n", + "30421007 0698286004 0.024525 2 1597622400 \n", + "30421008 0861478002 0.016339 2 1597622400 \n", + "30421009 0901955001 0.024525 2 1597622400 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "last_item_ids = df[df.customer_id == '0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef'\n", + " ].tail(10).article_id.values\n", + "df[df.customer_id == '0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef'].tail(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1dfa904a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:05.845512Z", + "iopub.status.busy": "2022-03-20T03:26:05.844574Z", + "iopub.status.idle": "2022-03-20T03:26:05.848469Z", + "shell.execute_reply": "2022-03-20T03:26:05.848986Z", + "shell.execute_reply.started": "2022-03-19T08:01:57.513098Z" + }, + "papermill": { + "duration": 0.092529, + "end_time": "2022-03-20T03:26:05.849142", + "exception": false, + "start_time": "2022-03-20T03:26:05.756613", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['0842727001', '0850241005', '0871517013', '0903062001',\n", + " '0861478003', '0857778011', '0698286004', '0698286004',\n", + " '0861478002', '0901955001'], dtype=object)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "last_item_ids" + ] + }, + { + "cell_type": "markdown", + "id": "e61e4796", + "metadata": { + "papermill": { + "duration": 0.081369, + "end_time": "2022-03-20T03:26:06.014083", + "exception": false, + "start_time": "2022-03-20T03:26:05.932714", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Recbole use an internal ids for identify user_id and item_id, so let convert this user_id and his items into internal ids.\n", + "* customer_id: `0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef` -> internal user id: 2\n", + "* last bought item_id: [..., '0698286004', '0861478002', '0901955001'] -> internal item id: [..., 3237, 4377, 4559]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "56e3d2ff", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:06.185713Z", + "iopub.status.busy": "2022-03-20T03:26:06.184790Z", + "iopub.status.idle": "2022-03-20T03:26:06.188503Z", + "shell.execute_reply": "2022-03-20T03:26:06.189097Z", + "shell.execute_reply.started": "2022-03-19T07:56:12.42476Z" + }, + "papermill": { + "duration": 0.093653, + "end_time": "2022-03-20T03:26:06.189274", + "exception": false, + "start_time": "2022-03-20T03:26:06.095621", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data.dataset.token2id(test_data.dataset.uid_field, \n", + " '0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef')" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5cb25ea8", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:06.365496Z", + "iopub.status.busy": "2022-03-20T03:26:06.364116Z", + "iopub.status.idle": "2022-03-20T03:26:06.370226Z", + "shell.execute_reply": "2022-03-20T03:26:06.371133Z", + "shell.execute_reply.started": "2022-03-19T08:02:33.042716Z" + }, + "papermill": { + "duration": 0.097778, + "end_time": "2022-03-20T03:26:06.371428", + "exception": false, + "start_time": "2022-03-20T03:26:06.273650", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 797 4975 6339 6070 4054 6745 3237 3237 4377 4559]\n" + ] + } + ], + "source": [ + "print(dataset.token2id(dataset.iid_field, last_item_ids))" + ] + }, + { + "cell_type": "markdown", + "id": "dfa2bf97", + "metadata": { + "papermill": { + "duration": 0.0822, + "end_time": "2022-03-20T03:26:06.539988", + "exception": false, + "start_time": "2022-03-20T03:26:06.457788", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "**Now let extract input features and labels in our test data.\n", + "My extracted code is copy from [this source](https://recbole.io/docs/_modules/recbole/utils/case_study.html#full_sort_scores)**" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "6585fd12", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:06.714700Z", + "iopub.status.busy": "2022-03-20T03:26:06.714009Z", + "iopub.status.idle": "2022-03-20T03:26:06.729048Z", + "shell.execute_reply": "2022-03-20T03:26:06.729986Z", + "shell.execute_reply.started": "2022-03-19T07:18:48.212806Z" + }, + "papermill": { + "duration": 0.105247, + "end_time": "2022-03-20T03:26:06.730346", + "exception": false, + "start_time": "2022-03-20T03:26:06.625099", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "The batch_size of interaction: 5\n", + " user_id, torch.Size([5]), cpu, torch.int64\n", + " item_id, torch.Size([5]), cpu, torch.int64\n", + " timestamp, torch.Size([5]), cpu, torch.float32\n", + " item_length, torch.Size([5]), cpu, torch.int64\n", + " item_id_list, torch.Size([5, 50]), cpu, torch.int64\n", + " timestamp_list, torch.Size([5, 50]), cpu, torch.float32\n" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_features = test_data.dataset[np.isin(test_data.dataset[test_data.dataset.uid_field].numpy(), [2])]\n", + "input_features" + ] + }, + { + "cell_type": "markdown", + "id": "352b35aa", + "metadata": { + "papermill": { + "duration": 0.099346, + "end_time": "2022-03-20T03:26:06.936217", + "exception": false, + "start_time": "2022-03-20T03:26:06.836871", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "* **item_id in above interaction is used as label item**\n", + "* **item_id_list in above interaction is used as feature items**\n", + "\n", + "Let check it" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "c45ec082", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:07.113806Z", + "iopub.status.busy": "2022-03-20T03:26:07.113012Z", + "iopub.status.idle": "2022-03-20T03:26:07.130415Z", + "shell.execute_reply": "2022-03-20T03:26:07.129743Z", + "shell.execute_reply.started": "2022-03-19T08:09:07.459035Z" + }, + "papermill": { + "duration": 0.109152, + "end_time": "2022-03-20T03:26:07.130579", + "exception": false, + "start_time": "2022-03-20T03:26:07.021427", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test label: tensor([6745, 3237, 3237, 4377, 4559])\n", + "last 10 items from origin dataset: [ 797 4975 6339 6070 4054 6745 3237 3237 4377 4559]\n" + ] + } + ], + "source": [ + "print(\"test label: \" + str(input_features['item_id']))\n", + "print(\"last 10 items from origin dataset: \" + str(dataset.token2id(dataset.iid_field, last_item_ids)))" + ] + }, + { + "cell_type": "markdown", + "id": "c9749297", + "metadata": { + "papermill": { + "duration": 0.088159, + "end_time": "2022-03-20T03:26:07.305671", + "exception": false, + "start_time": "2022-03-20T03:26:07.217512", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "As we expected, in last 10 items, 5 last items are used as label item. So for evaluating this user, we will have 5 sample rows in test data: \n", + "* Input feature: ? -> Output: 6745\n", + "* Input feature: ? -> Output: 3237\n", + "* Input feature: ? -> Output: 3237\n", + "* Input feature: ? -> Output: 4377\n", + "* Input feature: ? -> Output: 4559\n", + "\n", + "Now, let check input features in **item_id_list**" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "ae0203bf", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:07.490909Z", + "iopub.status.busy": "2022-03-20T03:26:07.489858Z", + "iopub.status.idle": "2022-03-20T03:26:07.497771Z", + "shell.execute_reply": "2022-03-20T03:26:07.497176Z", + "shell.execute_reply.started": "2022-03-19T07:18:48.243504Z" + }, + "papermill": { + "duration": 0.102151, + "end_time": "2022-03-20T03:26:07.497928", + "exception": false, + "start_time": "2022-03-20T03:26:07.395777", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 5, 6, 7, 8, 9, 10, 3566, 2149, 3123, 2739, 3548, 3545,\n", + " 3807, 4115, 949, 4476, 4398, 4109, 4449, 4449, 4024, 4570, 377, 5336,\n", + " 3491, 5340, 3608, 3608, 5138, 5138, 1880, 5234, 2442, 6000, 6000, 917,\n", + " 6092, 5555, 5555, 6257, 3216, 797, 4975, 6339, 6070, 4054, 0, 0,\n", + " 0, 0],\n", + " [ 5, 6, 7, 8, 9, 10, 3566, 2149, 3123, 2739, 3548, 3545,\n", + " 3807, 4115, 949, 4476, 4398, 4109, 4449, 4449, 4024, 4570, 377, 5336,\n", + " 3491, 5340, 3608, 3608, 5138, 5138, 1880, 5234, 2442, 6000, 6000, 917,\n", + " 6092, 5555, 5555, 6257, 3216, 797, 4975, 6339, 6070, 4054, 6745, 0,\n", + " 0, 0],\n", + " [ 5, 6, 7, 8, 9, 10, 3566, 2149, 3123, 2739, 3548, 3545,\n", + " 3807, 4115, 949, 4476, 4398, 4109, 4449, 4449, 4024, 4570, 377, 5336,\n", + " 3491, 5340, 3608, 3608, 5138, 5138, 1880, 5234, 2442, 6000, 6000, 917,\n", + " 6092, 5555, 5555, 6257, 3216, 797, 4975, 6339, 6070, 4054, 6745, 3237,\n", + " 0, 0],\n", + " [ 5, 6, 7, 8, 9, 10, 3566, 2149, 3123, 2739, 3548, 3545,\n", + " 3807, 4115, 949, 4476, 4398, 4109, 4449, 4449, 4024, 4570, 377, 5336,\n", + " 3491, 5340, 3608, 3608, 5138, 5138, 1880, 5234, 2442, 6000, 6000, 917,\n", + " 6092, 5555, 5555, 6257, 3216, 797, 4975, 6339, 6070, 4054, 6745, 3237,\n", + " 3237, 0],\n", + " [ 5, 6, 7, 8, 9, 10, 3566, 2149, 3123, 2739, 3548, 3545,\n", + " 3807, 4115, 949, 4476, 4398, 4109, 4449, 4449, 4024, 4570, 377, 5336,\n", + " 3491, 5340, 3608, 3608, 5138, 5138, 1880, 5234, 2442, 6000, 6000, 917,\n", + " 6092, 5555, 5555, 6257, 3216, 797, 4975, 6339, 6070, 4054, 6745, 3237,\n", + " 3237, 4377]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_features['item_id_list']" + ] + }, + { + "cell_type": "markdown", + "id": "3a4ebb13", + "metadata": { + "papermill": { + "duration": 0.087953, + "end_time": "2022-03-20T03:26:07.675590", + "exception": false, + "start_time": "2022-03-20T03:26:07.587637", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We can see:\n", + "* For 1st row, it uses all items in training as input features.\n", + "* For 2nd row, it uses all items in training + first label as input features\n", + "* For 3rd row, it uses all items in training + first label + second label as input features\n", + "* ...\n", + "* For last row, it uses all items except last item as input features.\n", + "\n", + "In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/notebook)), **I use last row result for recommendation, so we are missing information from last item. **\n", + "\n", + "So now let fix it- find a new way for using all items" + ] + }, + { + "cell_type": "markdown", + "id": "0da61c25", + "metadata": { + "papermill": { + "duration": 0.089862, + "end_time": "2022-03-20T03:26:07.852504", + "exception": false, + "start_time": "2022-03-20T03:26:07.762642", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 2. Custom code for using all items in recommendation" + ] + }, + { + "cell_type": "markdown", + "id": "642300a0", + "metadata": { + "papermill": { + "duration": 0.08523, + "end_time": "2022-03-20T03:26:08.025711", + "exception": false, + "start_time": "2022-03-20T03:26:07.940481", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We have seen that last row is missing only last item, so fixxing ideal is simple now:\n", + "* copy last row, add last item into it as a new interation(a row in test dataset)\n", + "* make prediction with new interation\n", + "\n", + "So now let train a dummy model for testing it\n", + "\n", + "### 1. Make dummy model" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "8e561976", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:08.204596Z", + "iopub.status.busy": "2022-03-20T03:26:08.200688Z", + "iopub.status.idle": "2022-03-20T03:26:42.476125Z", + "shell.execute_reply": "2022-03-20T03:26:42.467937Z", + "shell.execute_reply.started": "2022-03-19T07:18:48.254738Z" + }, + "papermill": { + "duration": 34.364819, + "end_time": "2022-03-20T03:26:42.476364", + "exception": false, + "start_time": "2022-03-20T03:26:08.111545", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GRU4Rec(\n", + " (item_embedding): Embedding(7330, 64, padding_idx=0)\n", + " (emb_dropout): Dropout(p=0.3, inplace=False)\n", + " (gru_layers): GRU(64, 128, bias=False, batch_first=True)\n", + " (dense): Linear(in_features=128, out_features=64, bias=True)\n", + " (loss_fct): CrossEntropyLoss()\n", + ")\n", + "Trainable parameters: 551104\n", + "epoch 0 training [time: 14.43s, train loss: 3384.6924]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_03-26-16.pth\n", + "epoch 1 training [time: 10.94s, train loss: 3260.1609]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_03-26-16.pth\n" + ] + } + ], + "source": [ + "# model loading and initialization\n", + "model = GRU4Rec(config, train_data.dataset).to(config['device'])\n", + "logger.info(model)\n", + "\n", + "# trainer loading and initialization\n", + "trainer = Trainer(config, model)\n", + "\n", + "# model training\n", + "best_valid_score, best_valid_result = trainer.fit(train_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "330ca134", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:42.672002Z", + "iopub.status.busy": "2022-03-20T03:26:42.671191Z", + "iopub.status.idle": "2022-03-20T03:26:42.674858Z", + "shell.execute_reply": "2022-03-20T03:26:42.675478Z", + "shell.execute_reply.started": "2022-03-19T07:19:18.347113Z" + }, + "papermill": { + "duration": 0.104305, + "end_time": "2022-03-20T03:26:42.675633", + "exception": false, + "start_time": "2022-03-20T03:26:42.571328", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "GRU4Rec(\n", + " (item_embedding): Embedding(7330, 64, padding_idx=0)\n", + " (emb_dropout): Dropout(p=0.3, inplace=False)\n", + " (gru_layers): GRU(64, 128, bias=False, batch_first=True)\n", + " (dense): Linear(in_features=128, out_features=64, bias=True)\n", + " (loss_fct): CrossEntropyLoss()\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "02773d9a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:42.870629Z", + "iopub.status.busy": "2022-03-20T03:26:42.869792Z", + "iopub.status.idle": "2022-03-20T03:26:42.873287Z", + "shell.execute_reply": "2022-03-20T03:26:42.873790Z", + "shell.execute_reply.started": "2022-03-19T07:19:18.355205Z" + }, + "papermill": { + "duration": 0.105847, + "end_time": "2022-03-20T03:26:42.873956", + "exception": false, + "start_time": "2022-03-20T03:26:42.768109", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([5, 50])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_features['item_id_list'].shape" + ] + }, + { + "cell_type": "markdown", + "id": "f22ab4c5", + "metadata": { + "papermill": { + "duration": 0.094002, + "end_time": "2022-03-20T03:26:43.063520", + "exception": false, + "start_time": "2022-03-20T03:26:42.969518", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Our sequence items is always have fix length(50). So if we have more than 50 items, we need to drop earlier items, and if there are less than 50 items, we need to add a padding(0) into input item features. As example:\n", + "* If last row input = [3, 4, 7,..., 20, 0, 0 ,0] (47 items < 50 item, so we have padding), after adding id=30, we will have input = [3, 4, 7,..., 20, 30, 0 ,0] Now our sequence lenght = 48 items.\n", + "* If last row input = [3, 4, 7,..., 20, 9, 10 ,12] (50 items), after adding id=30, we will have input = [4, 7,..., 9, 10, 12, 30] (drop first item and add last item).Now our sequence lenght still = 50 items.\n", + "\n", + "Now let implement it.\n", + "\n", + "First let extract last row from all interation when internal_user_id = 2 " + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "102bafb5", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:43.258540Z", + "iopub.status.busy": "2022-03-20T03:26:43.257542Z", + "iopub.status.idle": "2022-03-20T03:26:43.281126Z", + "shell.execute_reply": "2022-03-20T03:26:43.280610Z", + "shell.execute_reply.started": "2022-03-19T08:34:57.763854Z" + }, + "papermill": { + "duration": 0.122703, + "end_time": "2022-03-20T03:26:43.281269", + "exception": false, + "start_time": "2022-03-20T03:26:43.158566", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "The batch_size of interaction: 50\n", + " user_id, torch.Size([50]), cpu, torch.int64\n", + " item_id, torch.Size([50]), cpu, torch.int64\n", + " timestamp, torch.Size([50]), cpu, torch.float32\n", + " item_length, torch.Size([50]), cpu, torch.int64\n", + " item_id_list, torch.Size([50, 50]), cpu, torch.int64\n", + " timestamp_list, torch.Size([50, 50]), cpu, torch.float32\n" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index = np.isin(dataset[dataset.uid_field].numpy(), [2])\n", + "input_interaction = dataset[index]\n", + "input_interaction" + ] + }, + { + "cell_type": "markdown", + "id": "0349a339", + "metadata": { + "papermill": { + "duration": 0.096329, + "end_time": "2022-03-20T03:26:43.472695", + "exception": false, + "start_time": "2022-03-20T03:26:43.376366", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Now let add last item into sequences, and make new interaction.\n", + "We also need to edit sequence lenght (without padding)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "76330a66", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:43.675855Z", + "iopub.status.busy": "2022-03-20T03:26:43.674868Z", + "iopub.status.idle": "2022-03-20T03:26:43.683634Z", + "shell.execute_reply": "2022-03-20T03:26:43.684193Z", + "shell.execute_reply.started": "2022-03-19T08:39:53.562108Z" + }, + "papermill": { + "duration": 0.115591, + "end_time": "2022-03-20T03:26:43.684381", + "exception": false, + "start_time": "2022-03-20T03:26:43.568790", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "The batch_size of interaction: 1\n", + " item_id_list, torch.Size([1, 50]), cpu, torch.int64\n", + " item_length, torch.Size([1]), cpu, torch.int64\n" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "from recbole.data.interaction import Interaction\n", + "\n", + "def add_last_item(old_interaction, last_item_id, max_len=50):\n", + " new_seq_items = old_interaction['item_id_list'][-1]\n", + " if old_interaction['item_length'][-1].item() < max_len:\n", + " new_seq_items[input_interaction['item_length'][-1].item()] = last_item_id\n", + " else:\n", + " new_seq_items = torch.roll(new_seq_items, -1)\n", + " new_seq_items[-1] = last_item_id\n", + " return new_seq_items.view(1, len(new_seq_items))\n", + "\n", + "test = {\n", + " 'item_id_list': add_last_item(input_interaction, input_interaction['item_id'][-1].item(), model.max_seq_length),\n", + " 'item_length': torch.tensor(\n", + " [input_interaction['item_length'][-1].item() + 1\n", + " if input_interaction['item_length'][-1].item() < model.max_seq_length else model.max_seq_length])\n", + " }\n", + "new_inter = Interaction(test)\n", + "new_inter" + ] + }, + { + "cell_type": "markdown", + "id": "61a9aaa8", + "metadata": { + "papermill": { + "duration": 0.097313, + "end_time": "2022-03-20T03:26:43.878524", + "exception": false, + "start_time": "2022-03-20T03:26:43.781211", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Interaction for GRU4Rec model need to have only `item_id_list` and `item_lenght`. You can drop other key.\n", + "If you want more information, you can check [GRU4Rec code](https://recbole.io/docs/_modules/recbole/model/sequential_recommender/gru4rec.html#GRU4Rec)\n", + "\n", + "Then we can apply the remaining prediction code from [full_sort_scores](https://recbole.io/docs/_modules/recbole/utils/case_study.html#full_sort_scores)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "7da9030b", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:44.084630Z", + "iopub.status.busy": "2022-03-20T03:26:44.083600Z", + "iopub.status.idle": "2022-03-20T03:26:44.106014Z", + "shell.execute_reply": "2022-03-20T03:26:44.106586Z", + "shell.execute_reply.started": "2022-03-19T08:46:00.688378Z" + }, + "papermill": { + "duration": 0.129901, + "end_time": "2022-03-20T03:26:44.106769", + "exception": false, + "start_time": "2022-03-20T03:26:43.976868", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "new_inter = new_inter.to(config['device'])\n", + "new_scores = model.full_sort_predict(new_inter)\n", + "new_scores = new_scores.view(-1, test_data.dataset.item_num)\n", + "new_scores[:, 0] = -np.inf # set scores of [pad] to -inf" + ] + }, + { + "cell_type": "markdown", + "id": "9d7071fb", + "metadata": { + "papermill": { + "duration": 0.097247, + "end_time": "2022-03-20T03:26:44.301335", + "exception": false, + "start_time": "2022-03-20T03:26:44.204088", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "So now by combining all fragments,we have a new function for predicting with all item in dataset. You can use this custom code for all sequential model" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "28361728", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:44.517838Z", + "iopub.status.busy": "2022-03-20T03:26:44.516679Z", + "iopub.status.idle": "2022-03-20T03:26:44.519542Z", + "shell.execute_reply": "2022-03-20T03:26:44.520033Z", + "shell.execute_reply.started": "2022-03-19T08:50:43.786113Z" + }, + "papermill": { + "duration": 0.113123, + "end_time": "2022-03-20T03:26:44.520196", + "exception": false, + "start_time": "2022-03-20T03:26:44.407073", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "from recbole.data.interaction import Interaction\n", + "\n", + "def add_last_item(old_interaction, last_item_id, max_len=50):\n", + " new_seq_items = old_interaction['item_id_list'][-1]\n", + " if old_interaction['item_length'][-1].item() < max_len:\n", + " new_seq_items[old_interaction['item_length'][-1].item()] = last_item_id\n", + " else:\n", + " new_seq_items = torch.roll(new_seq_items, -1)\n", + " new_seq_items[-1] = last_item_id\n", + " return new_seq_items.view(1, len(new_seq_items))\n", + "\n", + "def predict_for_all_item(external_user_id, dataset, model):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " uid_series = dataset.token2id(dataset.uid_field, [external_user_id])\n", + " index = np.isin(dataset[dataset.uid_field].numpy(), uid_series)\n", + " input_interaction = dataset[index]\n", + " test = {\n", + " 'item_id_list': add_last_item(input_interaction, \n", + " input_interaction['item_id'][-1].item(), model.max_seq_length),\n", + " 'item_length': torch.tensor(\n", + " [input_interaction['item_length'][-1].item() + 1\n", + " if input_interaction['item_length'][-1].item() < model.max_seq_length else model.max_seq_length])\n", + " }\n", + " new_inter = Interaction(test)\n", + " new_inter = new_inter.to(config['device'])\n", + " new_scores = model.full_sort_predict(new_inter)\n", + " new_scores = new_scores.view(-1, test_data.dataset.item_num)\n", + " new_scores[:, 0] = -np.inf # set scores of [pad] to -inf\n", + " return torch.topk(new_scores, 10)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "07580e6e", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:26:44.719887Z", + "iopub.status.busy": "2022-03-20T03:26:44.718743Z", + "iopub.status.idle": "2022-03-20T03:26:44.789126Z", + "shell.execute_reply": "2022-03-20T03:26:44.789832Z", + "shell.execute_reply.started": "2022-03-19T08:49:01.592811Z" + }, + "papermill": { + "duration": 0.173425, + "end_time": "2022-03-20T03:26:44.790032", + "exception": false, + "start_time": "2022-03-20T03:26:44.616607", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.return_types.topk(\n", + "values=tensor([[5.4805, 5.2197, 5.1685, 5.1523, 5.0042, 4.9986, 4.9978, 4.9962, 4.9898,\n", + " 4.9557]], device='cuda:0'),\n", + "indices=tensor([[5412, 782, 2668, 1608, 4579, 2426, 294, 2427, 5835, 2669]],\n", + " device='cuda:0'))" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict_for_all_item('0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef', \n", + " dataset, model) # we feed directly origin dataset, not train data or test data" + ] + }, + { + "cell_type": "markdown", + "id": "9c7564dc", + "metadata": { + "papermill": { + "duration": 0.098422, + "end_time": "2022-03-20T03:26:44.992841", + "exception": false, + "start_time": "2022-03-20T03:26:44.894419", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Congratulation !!!.Now you can use all data as train set, don't need for a test set, but still can predict directly from dataset without testset.Now let apply it into our previous notebook.\n", + "\n", + "I have create new notebooks for applying our customize function for using all items as input for recommendation:\n", + "* Using only interactions: https://www.kaggle.com/astrung/sequential-model-fixed-missing-last-item\n", + "* Using interactions with item features: https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item\n", + "\n", + "Please check and upvote it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efefe8dc", + "metadata": { + "papermill": { + "duration": 0.098691, + "end_time": "2022-03-20T03:26:45.190857", + "exception": false, + "start_time": "2022-03-20T03:26:45.092166", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + }, + "papermill": { + "default_parameters": {}, + "duration": 309.140629, + "end_time": "2022-03-20T03:26:48.101763", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2022-03-20T03:21:38.961134", + "version": "2.3.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/run_example/sequential-model-fixed-missing-last-item.ipynb b/run_example/sequential-model-fixed-missing-last-item.ipynb new file mode 100644 index 000000000..4cb98d3c8 --- /dev/null +++ b/run_example/sequential-model-fixed-missing-last-item.ipynb @@ -0,0 +1,2865 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "39c294a5", + "metadata": { + "papermill": { + "duration": 0.044353, + "end_time": "2022-03-20T02:19:14.265063", + "exception": false, + "start_time": "2022-03-20T02:19:14.220710", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 0.Overview\n", + "**Edit**:\n", + "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", + "* So i created a notebook [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction) for address there questions in detail, and this notebook is an improved of my [previous notebook](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial), applying our new function (using all item as input features without `full_sort_topk`) for this competition.\n", + "* I also create a improved version for adding item features into model in this [notebook](https://www.kaggle.com/astrung/lstm-model-with-item-infor-fix-missing-last-item). It improved a little score when add item features for recommendation\n", + "\n", + "- - -\n", + "\n", + "\n", + "This notebook demonstrate how to use LSTM for recomendation system.\n", + "I am using Recbole as an open source, as it has so many built-in models for recommendation(CNN, GRU-LSTM, Context-aware, Graph). In this notebook, we tried to use GRU/LSTM model for testing effect of sequential model for recommendation.\n", + "\n", + "Due to memory limit and faster testing purpose, we will just use data in 2020.\n", + "\n", + "If you want to use with all of interactions in all time, i have created a new atomic dataset here for you: https://www.kaggle.com/astrung/hm-atomic-interation\n", + "\n", + "We also have other limit: we only train model and predict with users who buy more than 40 items and items which is bought by more than 40 people.\n", + "\n", + "We will follow below steps for creating model:\n", + "\n", + "1. In order to use Recbole, we create atomic file from interaction data\n", + "2. Because we only use Recbole model for predicting with users who buy more than 40 items, other users will need to fill by default recomendation items. We create most viewed items in last month as defautl recomendation\n", + "3. We create dataset and train model in recbole.\n", + "4. We create prediction result by trained model\n", + "5. We combine recomendation result from most viewed items in last month and Recbole predicted model.\n", + "\n", + "I will explain more detail in following cells.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "88df5f58", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:19:14.361854Z", + "iopub.status.busy": "2022-03-20T02:19:14.360768Z", + "iopub.status.idle": "2022-03-20T02:19:35.871726Z", + "shell.execute_reply": "2022-03-20T02:19:35.870606Z", + "shell.execute_reply.started": "2022-03-19T09:03:25.544116Z" + }, + "papermill": { + "duration": 21.563461, + "end_time": "2022-03-20T02:19:35.872051", + "exception": false, + "start_time": "2022-03-20T02:19:14.308590", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting recbole\r\n", + " Downloading recbole-1.0.1-py3-none-any.whl (2.0 MB)\r\n", + " |████████████████████████████████| 2.0 MB 618 kB/s \r\n", + "\u001b[?25hRequirement already satisfied: colorama==0.4.4 in /opt/conda/lib/python3.7/site-packages (from recbole) (0.4.4)\r\n", + "Requirement already satisfied: scikit-learn>=0.23.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (0.23.2)\r\n", + "Requirement already satisfied: pyyaml>=5.1.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (6.0)\r\n", + "Collecting colorlog==4.7.2\r\n", + " Downloading colorlog-4.7.2-py2.py3-none-any.whl (10 kB)\r\n", + "Requirement already satisfied: torch>=1.7.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.9.1)\r\n", + "Collecting scipy==1.6.0\r\n", + " Downloading scipy-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (27.4 MB)\r\n", + " |████████████████████████████████| 27.4 MB 125 kB/s \r\n", + "\u001b[?25hRequirement already satisfied: pandas>=1.0.5 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.3.5)\r\n", + "Requirement already satisfied: tqdm>=4.48.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (4.62.3)\r\n", + "Requirement already satisfied: numpy>=1.17.2 in /opt/conda/lib/python3.7/site-packages (from recbole) (1.20.3)\r\n", + "Requirement already satisfied: tensorboard>=2.5.0 in /opt/conda/lib/python3.7/site-packages (from recbole) (2.6.0)\r\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas>=1.0.5->recbole) (2.8.2)\r\n", + "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas>=1.0.5->recbole) (2021.3)\r\n", + "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn>=0.23.2->recbole) (1.1.0)\r\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn>=0.23.2->recbole) (3.0.0)\r\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.6.1)\r\n", + "Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.37.0)\r\n", + "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (59.5.0)\r\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.8.0)\r\n", + "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.15.0)\r\n", + "Requirement already satisfied: protobuf>=3.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (3.19.1)\r\n", + "Requirement already satisfied: google-auth<2,>=1.6.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.35.0)\r\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (2.26.0)\r\n", + "Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (1.43.0)\r\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (2.0.2)\r\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (0.4.6)\r\n", + "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.7/site-packages (from tensorboard>=2.5.0->recbole) (3.3.6)\r\n", + "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.7.0->recbole) (4.0.1)\r\n", + "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from absl-py>=0.4->tensorboard>=2.5.0->recbole) (1.16.0)\r\n", + "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (4.2.4)\r\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (4.8)\r\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (0.2.7)\r\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.5.0->recbole) (1.3.0)\r\n", + "Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.7/site-packages (from markdown>=2.6.8->tensorboard>=2.5.0->recbole) (4.10.1)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (3.1)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (2021.10.8)\r\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (2.0.9)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard>=2.5.0->recbole) (1.26.7)\r\n", + "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.5.0->recbole) (3.6.0)\r\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard>=2.5.0->recbole) (0.4.8)\r\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.5.0->recbole) (3.1.1)\r\n", + "Installing collected packages: scipy, colorlog, recbole\r\n", + " Attempting uninstall: scipy\r\n", + " Found existing installation: scipy 1.7.3\r\n", + " Uninstalling scipy-1.7.3:\r\n", + " Successfully uninstalled scipy-1.7.3\r\n", + " Attempting uninstall: colorlog\r\n", + " Found existing installation: colorlog 6.6.0\r\n", + " Uninstalling colorlog-6.6.0:\r\n", + " Successfully uninstalled colorlog-6.6.0\r\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n", + "yellowbrick 1.3.post1 requires numpy<1.20,>=1.16.0, but you have numpy 1.20.3 which is incompatible.\r\n", + "pdpbox 0.2.1 requires matplotlib==3.1.1, but you have matplotlib 3.5.1 which is incompatible.\r\n", + "imbalanced-learn 0.9.0 requires scikit-learn>=1.0.1, but you have scikit-learn 0.23.2 which is incompatible.\r\n", + "featuretools 1.4.1 requires numpy>=1.21.0, but you have numpy 1.20.3 which is incompatible.\r\n", + "arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.0.1 which is incompatible.\u001b[0m\r\n", + "Successfully installed colorlog-4.7.2 recbole-1.0.1 scipy-1.6.0\r\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\r\n" + ] + } + ], + "source": [ + "!pip install recbole" + ] + }, + { + "cell_type": "markdown", + "id": "4afb60af", + "metadata": { + "papermill": { + "duration": 0.067053, + "end_time": "2022-03-20T02:19:36.014583", + "exception": false, + "start_time": "2022-03-20T02:19:35.947530", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 1. Create atomic file" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a1d01d80", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:19:36.155095Z", + "iopub.status.busy": "2022-03-20T02:19:36.154434Z", + "iopub.status.idle": "2022-03-20T02:20:43.709687Z", + "shell.execute_reply": "2022-03-20T02:20:43.710329Z", + "shell.execute_reply.started": "2022-03-19T09:03:46.361987Z" + }, + "papermill": { + "duration": 67.62903, + "end_time": "2022-03-20T02:20:43.710549", + "exception": false, + "start_time": "2022-03-20T02:19:36.081519", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_id
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.0508312
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.0304922
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.0152372
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.0169322
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.0169322
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id article_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0663713001 \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0541518023 \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0505221004 \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687003 \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687004 \n", + "\n", + " price sales_channel_id \n", + "0 0.050831 2 \n", + "1 0.030492 2 \n", + "2 0.015237 2 \n", + "3 0.016932 2 \n", + "4 0.016932 2 " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "import gc\n", + "df = pd.read_csv(r\"/kaggle/input/h-and-m-personalized-fashion-recommendations/transactions_train.csv\", \n", + " dtype={'article_id': 'str'})\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "535f1f8f", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:20:45.628510Z", + "iopub.status.busy": "2022-03-20T02:20:45.627430Z", + "iopub.status.idle": "2022-03-20T02:20:50.180018Z", + "shell.execute_reply": "2022-03-20T02:20:50.179424Z", + "shell.execute_reply.started": "2022-03-19T09:04:50.295558Z" + }, + "papermill": { + "duration": 6.4014, + "end_time": "2022-03-20T02:20:50.180176", + "exception": false, + "start_time": "2022-03-20T02:20:43.778776", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_id
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.0508312
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.0304922
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.0152372
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.0169322
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.0169322
..................
317883192020-09-22fff2282977442e327b45d8c89afde25617d00124d0f999...09295110010.0593052
317883202020-09-22fff2282977442e327b45d8c89afde25617d00124d0f999...08913220040.0423562
317883212020-09-22fff380805474b287b05cb2a7507b9a013482f7dd0bce0e...09183250010.0432031
317883222020-09-22fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5...08334590020.0067631
317883232020-09-22fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20...08985730030.0338812
\n", + "

31788324 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " t_dat customer_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... \n", + "... ... ... \n", + "31788319 2020-09-22 fff2282977442e327b45d8c89afde25617d00124d0f999... \n", + "31788320 2020-09-22 fff2282977442e327b45d8c89afde25617d00124d0f999... \n", + "31788321 2020-09-22 fff380805474b287b05cb2a7507b9a013482f7dd0bce0e... \n", + "31788322 2020-09-22 fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5... \n", + "31788323 2020-09-22 fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20... \n", + "\n", + " article_id price sales_channel_id \n", + "0 0663713001 0.050831 2 \n", + "1 0541518023 0.030492 2 \n", + "2 0505221004 0.015237 2 \n", + "3 0685687003 0.016932 2 \n", + "4 0685687004 0.016932 2 \n", + "... ... ... ... \n", + "31788319 0929511001 0.059305 2 \n", + "31788320 0891322004 0.042356 2 \n", + "31788321 0918325001 0.043203 1 \n", + "31788322 0833459002 0.006763 1 \n", + "31788323 0898573003 0.033881 2 \n", + "\n", + "[31788324 rows x 5 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df['t_dat'] = pd.to_datetime(df['t_dat'], format=\"%Y-%m-%d\")\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "28043b68", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:20:50.327100Z", + "iopub.status.busy": "2022-03-20T02:20:50.325671Z", + "iopub.status.idle": "2022-03-20T02:20:50.923223Z", + "shell.execute_reply": "2022-03-20T02:20:50.923781Z", + "shell.execute_reply.started": "2022-03-19T09:04:56.537845Z" + }, + "papermill": { + "duration": 0.673267, + "end_time": "2022-03-20T02:20:50.923977", + "exception": false, + "start_time": "2022-03-20T02:20:50.250710", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
t_datcustomer_idarticle_idpricesales_channel_idtimestamp
02018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...06637130010.05083121537401600
12018-09-20000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...05415180230.03049221537401600
22018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...05052210040.01523721537401600
32018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870030.01693221537401600
42018-09-2000007d2de826758b65a93dd24ce629ed66842531df6699...06856870040.01693221537401600
\n", + "
" + ], + "text/plain": [ + " t_dat customer_id article_id \\\n", + "0 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0663713001 \n", + "1 2018-09-20 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0541518023 \n", + "2 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0505221004 \n", + "3 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687003 \n", + "4 2018-09-20 00007d2de826758b65a93dd24ce629ed66842531df6699... 0685687004 \n", + "\n", + " price sales_channel_id timestamp \n", + "0 0.050831 2 1537401600 \n", + "1 0.030492 2 1537401600 \n", + "2 0.015237 2 1537401600 \n", + "3 0.016932 2 1537401600 \n", + "4 0.016932 2 1537401600 " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "df['timestamp'] = df.t_dat.values.astype(np.int64) // 10 ** 9\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "2b77607f", + "metadata": { + "papermill": { + "duration": 0.06931, + "end_time": "2022-03-20T02:20:51.061588", + "exception": false, + "start_time": "2022-03-20T02:20:50.992278", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "**We fill with data in only 2020(timestapm > > 1585620000) and create inter file**\n", + "For anyone need instruction about inter file, please check below links:\n", + "* https://recbole.io/docs/user_guide/data_intro.html\n", + "* https://recbole.io/docs/user_guide/data/atomic_files.html" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "892d6c5d", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:20:51.240531Z", + "iopub.status.busy": "2022-03-20T02:20:51.237724Z", + "iopub.status.idle": "2022-03-20T02:20:53.162136Z", + "shell.execute_reply": "2022-03-20T02:20:53.161463Z", + "shell.execute_reply.started": "2022-03-19T09:04:57.127124Z" + }, + "papermill": { + "duration": 2.032094, + "end_time": "2022-03-20T02:20:53.162319", + "exception": false, + "start_time": "2022-03-20T02:20:51.130225", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_id:tokenitem_id:tokentimestamp:float
23934157000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...07278080011585699200
23934158000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...07278080071585699200
23934159000563485cbb7850b0a93c6606f89c5b961c6647d1bd48...05675320151585699200
23934160000563485cbb7850b0a93c6606f89c5b961c6647d1bd48...07061040091585699200
2393416100083cda041544b2fbb0e0d2905ad17da7cf1007526fb4...07835040041585699200
............
31788319fff2282977442e327b45d8c89afde25617d00124d0f999...09295110011600732800
31788320fff2282977442e327b45d8c89afde25617d00124d0f999...08913220041600732800
31788321fff380805474b287b05cb2a7507b9a013482f7dd0bce0e...09183250011600732800
31788322fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5...08334590021600732800
31788323fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20...08985730031600732800
\n", + "

7854167 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " user_id:token item_id:token \\\n", + "23934157 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0727808001 \n", + "23934158 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... 0727808007 \n", + "23934159 000563485cbb7850b0a93c6606f89c5b961c6647d1bd48... 0567532015 \n", + "23934160 000563485cbb7850b0a93c6606f89c5b961c6647d1bd48... 0706104009 \n", + "23934161 00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4... 0783504004 \n", + "... ... ... \n", + "31788319 fff2282977442e327b45d8c89afde25617d00124d0f999... 0929511001 \n", + "31788320 fff2282977442e327b45d8c89afde25617d00124d0f999... 0891322004 \n", + "31788321 fff380805474b287b05cb2a7507b9a013482f7dd0bce0e... 0918325001 \n", + "31788322 fff4d3a8b1f3b60af93e78c30a7cb4cf75edaf2590d3e5... 0833459002 \n", + "31788323 fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20... 0898573003 \n", + "\n", + " timestamp:float \n", + "23934157 1585699200 \n", + "23934158 1585699200 \n", + "23934159 1585699200 \n", + "23934160 1585699200 \n", + "23934161 1585699200 \n", + "... ... \n", + "31788319 1600732800 \n", + "31788320 1600732800 \n", + "31788321 1600732800 \n", + "31788322 1600732800 \n", + "31788323 1600732800 \n", + "\n", + "[7854167 rows x 3 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp = df[df['timestamp'] > 1585620000][['customer_id', 'article_id', 'timestamp']].rename(\n", + " columns={'customer_id': 'user_id:token', 'article_id': 'item_id:token', 'timestamp': 'timestamp:float'})\n", + "temp" + ] + }, + { + "cell_type": "markdown", + "id": "ae5d0ec3", + "metadata": { + "papermill": { + "duration": 0.071609, + "end_time": "2022-03-20T02:20:53.304714", + "exception": false, + "start_time": "2022-03-20T02:20:53.233105", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "We save atomic file in dataset format for using with recbole" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "37f98c58", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:20:53.493765Z", + "iopub.status.busy": "2022-03-20T02:20:53.454009Z", + "iopub.status.idle": "2022-03-20T02:21:31.746398Z", + "shell.execute_reply": "2022-03-20T02:21:31.746994Z", + "shell.execute_reply.started": "2022-03-19T09:04:58.814912Z" + }, + "papermill": { + "duration": 38.371105, + "end_time": "2022-03-20T02:21:31.747188", + "exception": false, + "start_time": "2022-03-20T02:20:53.376083", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "160" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "!mkdir /kaggle/working/recbox_data\n", + "temp.to_csv('/kaggle/working/recbox_data/recbox_data.inter', index=False, sep='\\t')\n", + "del temp\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "2c0c4ed1", + "metadata": { + "papermill": { + "duration": 0.069772, + "end_time": "2022-03-20T02:21:31.887123", + "exception": false, + "start_time": "2022-03-20T02:21:31.817351", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 2. We create defautl recomendation for user who can not be predicted by sequential model.\n", + "I use this approach in notebook: https://www.kaggle.com/hervind/h-m-faster-trending-products-weekly You can check it for more detail information. I will juse copy only code here" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f06dc674", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:21:32.036991Z", + "iopub.status.busy": "2022-03-20T02:21:32.035445Z", + "iopub.status.idle": "2022-03-20T02:21:32.037770Z", + "shell.execute_reply": "2022-03-20T02:21:32.038388Z", + "shell.execute_reply.started": "2022-03-19T09:05:34.218845Z" + }, + "papermill": { + "duration": 0.079481, + "end_time": "2022-03-20T02:21:32.038570", + "exception": false, + "start_time": "2022-03-20T02:21:31.959089", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c6fb55ab", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:21:32.186012Z", + "iopub.status.busy": "2022-03-20T02:21:32.185168Z", + "iopub.status.idle": "2022-03-20T02:22:25.689493Z", + "shell.execute_reply": "2022-03-20T02:22:25.690149Z", + "shell.execute_reply.started": "2022-03-19T09:05:34.226052Z" + }, + "papermill": { + "duration": 53.581364, + "end_time": "2022-03-20T02:22:25.690334", + "exception": false, + "start_time": "2022-03-20T02:21:32.108970", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((1371980, 2), (1371980, 2), (1371980, 2))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sub0 = pd.read_csv('../input/hm-pre-recommendation/submissio_byfone_chris.csv').sort_values('customer_id').reset_index(drop=True)\n", + "sub1 = pd.read_csv('../input/hm-pre-recommendation/submission_trending.csv').sort_values('customer_id').reset_index(drop=True)\n", + "sub2 = pd.read_csv('../input/hm-pre-recommendation/submission_exponential_decay.csv').sort_values('customer_id').reset_index(drop=True)\n", + "\n", + "sub0.shape, sub1.shape, sub2.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "502968dc", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:22:25.839198Z", + "iopub.status.busy": "2022-03-20T02:22:25.837927Z", + "iopub.status.idle": "2022-03-20T02:22:26.105094Z", + "shell.execute_reply": "2022-03-20T02:22:26.105816Z", + "shell.execute_reply.started": "2022-03-19T09:06:27.429248Z" + }, + "papermill": { + "duration": 0.344956, + "end_time": "2022-03-20T02:22:26.106058", + "exception": false, + "start_time": "2022-03-20T02:22:25.761102", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction0prediction1prediction2
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 07...0568601043 0568601006 0656719005 0745232001 07...0568601043 0924243001 0924243002 0918522001 07...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0739590027 0723529001 08...0826211002 0800436010 0739590027 0723529001 08...0924243001 0924243002 0918522001 0751471001 04...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 07...0794321007 0852643001 0852643003 0858883002 07...0794321007 0924243001 0924243002 0918522001 07...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0751471001 0706016001 06...0448509014 0573085028 0751471001 0706016001 06...0924243001 0924243002 0918522001 0751471001 04...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0896152002 0818320001 09...0730683050 0791587015 0896152002 0818320001 09...0924243001 0924243002 0918522001 0751471001 04...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction0 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction1 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction2 \n", + "0 0568601043 0924243001 0924243002 0918522001 07... \n", + "1 0924243001 0924243002 0918522001 0751471001 04... \n", + "2 0794321007 0924243001 0924243002 0918522001 07... \n", + "3 0924243001 0924243002 0918522001 0751471001 04... \n", + "4 0924243001 0924243002 0918522001 0751471001 04... " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sub0.columns = ['customer_id', 'prediction0']\n", + "sub0['prediction1'] = sub1['prediction']\n", + "sub0['prediction2'] = sub2['prediction']\n", + "del sub1, sub2\n", + "gc.collect()\n", + "sub0.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "81843509", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:22:26.383905Z", + "iopub.status.busy": "2022-03-20T02:22:26.254160Z", + "iopub.status.idle": "2022-03-20T02:26:05.565638Z", + "shell.execute_reply": "2022-03-20T02:26:05.566210Z", + "shell.execute_reply.started": "2022-03-19T09:06:27.705565Z" + }, + "papermill": { + "duration": 219.388709, + "end_time": "2022-03-20T02:26:05.566402", + "exception": false, + "start_time": "2022-03-20T02:22:26.177693", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction0prediction1prediction2prediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 07...0568601043 0568601006 0656719005 0745232001 07...0568601043 0924243001 0924243002 0918522001 07...0568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0739590027 0723529001 08...0826211002 0800436010 0739590027 0723529001 08...0924243001 0924243002 0918522001 0751471001 04...0826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 07...0794321007 0852643001 0852643003 0858883002 07...0794321007 0924243001 0924243002 0918522001 07...0794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0751471001 0706016001 06...0448509014 0573085028 0751471001 0706016001 06...0924243001 0924243002 0918522001 0751471001 04...0448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0896152002 0818320001 09...0730683050 0791587015 0896152002 0818320001 09...0924243001 0924243002 0918522001 0751471001 04...0730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction0 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction1 \\\n", + "0 0568601043 0568601006 0656719005 0745232001 07... \n", + "1 0826211002 0800436010 0739590027 0723529001 08... \n", + "2 0794321007 0852643001 0852643003 0858883002 07... \n", + "3 0448509014 0573085028 0751471001 0706016001 06... \n", + "4 0730683050 0791587015 0896152002 0818320001 09... \n", + "\n", + " prediction2 \\\n", + "0 0568601043 0924243001 0924243002 0918522001 07... \n", + "1 0924243001 0924243002 0918522001 0751471001 04... \n", + "2 0794321007 0924243001 0924243002 0918522001 07... \n", + "3 0924243001 0924243002 0918522001 0751471001 04... \n", + "4 0924243001 0924243002 0918522001 0751471001 04... \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def cust_blend(dt, W = [1,1,1]):\n", + " #Global ensemble weights\n", + " #W = [1.15,0.95,0.85]\n", + " \n", + " #Create a list of all model predictions\n", + " REC = []\n", + " REC.append(dt['prediction0'].split())\n", + " REC.append(dt['prediction1'].split())\n", + " REC.append(dt['prediction2'].split())\n", + " \n", + " #Create a dictionary of items recommended. \n", + " #Assign a weight according the order of appearance and multiply by global weights\n", + " res = {}\n", + " for M in range(len(REC)):\n", + " for n, v in enumerate(REC[M]):\n", + " if v in res:\n", + " res[v] += (W[M]/(n+1))\n", + " else:\n", + " res[v] = (W[M]/(n+1))\n", + " \n", + " # Sort dictionary by item weights\n", + " res = list(dict(sorted(res.items(), key=lambda item: -item[1])).keys())\n", + " \n", + " # Return the top 12 itens only\n", + " return ' '.join(res[:12])\n", + "\n", + "sub0['prediction'] = sub0.apply(cust_blend, W = [1.05,1.00,0.95], axis=1)\n", + "sub0.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0d39bc15", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:26:05.807778Z", + "iopub.status.busy": "2022-03-20T02:26:05.806566Z", + "iopub.status.idle": "2022-03-20T02:26:17.032326Z", + "shell.execute_reply": "2022-03-20T02:26:17.031696Z", + "shell.execute_reply.started": "2022-03-19T09:09:54.008829Z" + }, + "papermill": { + "duration": 11.394565, + "end_time": "2022-03-20T02:26:17.032497", + "exception": false, + "start_time": "2022-03-20T02:26:05.637932", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "del sub0['prediction0']\n", + "del sub0['prediction1']\n", + "del sub0['prediction2']\n", + "gc.collect()\n", + "sub0.to_csv(f'submission.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "572a907a", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:26:17.283022Z", + "iopub.status.busy": "2022-03-20T02:26:17.281031Z", + "iopub.status.idle": "2022-03-20T02:26:17.287735Z", + "shell.execute_reply": "2022-03-20T02:26:17.287053Z", + "shell.execute_reply.started": "2022-03-19T09:10:04.696284Z" + }, + "papermill": { + "duration": 0.183306, + "end_time": "2022-03-20T02:26:17.287888", + "exception": false, + "start_time": "2022-03-20T02:26:17.104582", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del sub0\n", + "del df\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "91a7316f", + "metadata": { + "papermill": { + "duration": 0.072866, + "end_time": "2022-03-20T02:26:17.434227", + "exception": false, + "start_time": "2022-03-20T02:26:17.361361", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 3. Create dataset and train model with Recbole\n", + "\n", + "For anyone need instruction document, please check this link: https://recbole.io/docs/user_guide/usage/use_modules.html" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "861094fc", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:26:17.582780Z", + "iopub.status.busy": "2022-03-20T02:26:17.581666Z", + "iopub.status.idle": "2022-03-20T02:26:20.280026Z", + "shell.execute_reply": "2022-03-20T02:26:20.280632Z", + "shell.execute_reply.started": "2022-03-19T09:10:04.781373Z" + }, + "papermill": { + "duration": 2.773509, + "end_time": "2022-03-20T02:26:20.280830", + "exception": false, + "start_time": "2022-03-20T02:26:17.507321", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import logging\n", + "from logging import getLogger\n", + "from recbole.config import Config\n", + "from recbole.data import create_dataset, data_preparation\n", + "from recbole.model.sequential_recommender import GRU4Rec\n", + "from recbole.trainer import Trainer\n", + "from recbole.utils import init_seed, init_logger" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c211ad36", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:26:20.438539Z", + "iopub.status.busy": "2022-03-20T02:26:20.437533Z", + "iopub.status.idle": "2022-03-20T02:26:21.300145Z", + "shell.execute_reply": "2022-03-20T02:26:21.007505Z", + "shell.execute_reply.started": "2022-03-19T09:10:07.54699Z" + }, + "papermill": { + "duration": 0.946555, + "end_time": "2022-03-20T02:26:21.300340", + "exception": false, + "start_time": "2022-03-20T02:26:20.353785", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "General Hyper Parameters:\n", + "gpu_id = 0\n", + "use_gpu = True\n", + "seed = 2020\n", + "state = INFO\n", + "reproducibility = True\n", + "data_path = /kaggle/working/recbox_data\n", + "checkpoint_dir = saved\n", + "show_progress = True\n", + "save_dataset = False\n", + "dataset_save_path = None\n", + "save_dataloaders = False\n", + "dataloaders_save_path = None\n", + "log_wandb = False\n", + "\n", + "Training Hyper Parameters:\n", + "epochs = 50\n", + "train_batch_size = 2048\n", + "learner = adam\n", + "learning_rate = 0.001\n", + "neg_sampling = None\n", + "eval_step = 1\n", + "stopping_step = 10\n", + "clip_grad_norm = None\n", + "weight_decay = 0.0\n", + "loss_decimal_place = 4\n", + "\n", + "Evaluation Hyper Parameters:\n", + "eval_args = {'split': {'RS': [10, 0, 0]}, 'group_by': 'user', 'order': 'TO', 'mode': 'full'}\n", + "repeatable = True\n", + "metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n", + "topk = [10]\n", + "valid_metric = MRR@10\n", + "valid_metric_bigger = True\n", + "eval_batch_size = 4096\n", + "metric_decimal_place = 4\n", + "\n", + "Dataset Hyper Parameters:\n", + "field_separator = \t\n", + "seq_separator = \n", + "USER_ID_FIELD = user_id\n", + "ITEM_ID_FIELD = item_id\n", + "RATING_FIELD = rating\n", + "TIME_FIELD = timestamp\n", + "seq_len = None\n", + "LABEL_FIELD = label\n", + "threshold = None\n", + "NEG_PREFIX = neg_\n", + "load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n", + "unload_col = None\n", + "unused_col = None\n", + "additional_feat_suffix = None\n", + "rm_dup_inter = None\n", + "val_interval = None\n", + "filter_inter_by_user_or_item = True\n", + "user_inter_num_interval = [30,inf)\n", + "item_inter_num_interval = [40,inf)\n", + "alias_of_user_id = None\n", + "alias_of_item_id = None\n", + "alias_of_entity_id = None\n", + "alias_of_relation_id = None\n", + "preload_weight = None\n", + "normalize_field = None\n", + "normalize_all = None\n", + "ITEM_LIST_LENGTH_FIELD = item_length\n", + "LIST_SUFFIX = _list\n", + "MAX_ITEM_LIST_LENGTH = 50\n", + "POSITION_FIELD = position_id\n", + "HEAD_ENTITY_ID_FIELD = head_id\n", + "TAIL_ENTITY_ID_FIELD = tail_id\n", + "RELATION_ID_FIELD = relation_id\n", + "ENTITY_ID_FIELD = entity_id\n", + "benchmark_filename = None\n", + "\n", + "Other Hyper Parameters: \n", + "wandb_project = recbole\n", + "require_pow = False\n", + "embedding_size = 64\n", + "hidden_size = 128\n", + "num_layers = 1\n", + "dropout_prob = 0.3\n", + "loss_type = CE\n", + "MODEL_TYPE = ModelType.SEQUENTIAL\n", + "MODEL_INPUT_TYPE = InputType.POINTWISE\n", + "eval_type = EvaluatorType.RANKING\n", + "device = cuda\n", + "train_neg_sample_args = {'strategy': 'none'}\n", + "eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "parameter_dict = {\n", + " 'data_path': '/kaggle/working',\n", + " 'USER_ID_FIELD': 'user_id',\n", + " 'ITEM_ID_FIELD': 'item_id',\n", + " 'TIME_FIELD': 'timestamp',\n", + " 'user_inter_num_interval': \"[30,inf)\",\n", + " 'item_inter_num_interval': \"[40,inf)\",\n", + " 'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},\n", + " 'neg_sampling': None,\n", + " 'epochs': 50,\n", + " 'eval_args': {\n", + " 'split': {'RS': [10, 0, 0]},\n", + " 'group_by': 'user',\n", + " 'order': 'TO',\n", + " 'mode': 'full'}\n", + "}\n", + "\n", + "config = Config(model='GRU4Rec', dataset='recbox_data', config_dict=parameter_dict)\n", + "\n", + "# init random seed\n", + "init_seed(config['seed'], config['reproducibility'])\n", + "\n", + "# logger initialization\n", + "init_logger(config)\n", + "logger = getLogger()\n", + "# Create handlers\n", + "c_handler = logging.StreamHandler()\n", + "c_handler.setLevel(logging.INFO)\n", + "logger.addHandler(c_handler)\n", + "\n", + "# write config info into log\n", + "logger.info(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "548e2d03", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:26:21.722138Z", + "iopub.status.busy": "2022-03-20T02:26:21.720998Z", + "iopub.status.idle": "2022-03-20T02:27:55.478058Z", + "shell.execute_reply": "2022-03-20T02:27:55.460191Z", + "shell.execute_reply.started": "2022-03-19T09:10:08.225437Z" + }, + "papermill": { + "duration": 93.969404, + "end_time": "2022-03-20T02:27:55.478219", + "exception": false, + "start_time": "2022-03-20T02:26:21.508815", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "recbox_data\n", + "The number of users: 38916\n", + "Average actions of users: 47.47241423615572\n", + "The number of items: 10962\n", + "Average actions of items: 168.54201259009216\n", + "The number of inters: 1847389\n", + "The sparsity of the dataset: 99.56694768867584%\n", + "Remain Fields: ['user_id', 'item_id', 'timestamp']\n" + ] + } + ], + "source": [ + "dataset = create_dataset(config)\n", + "logger.info(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "7c74775e", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:27:55.906921Z", + "iopub.status.busy": "2022-03-20T02:27:55.906288Z", + "iopub.status.idle": "2022-03-20T02:28:37.269123Z", + "shell.execute_reply": "2022-03-20T02:28:37.252513Z", + "shell.execute_reply.started": "2022-03-19T09:11:36.745731Z" + }, + "papermill": { + "duration": 41.580573, + "end_time": "2022-03-20T02:28:37.269274", + "exception": false, + "start_time": "2022-03-20T02:27:55.688701", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Training]: train_batch_size = [2048] negative sampling: [None]\n", + "[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [10, 0, 0]}, 'group_by': 'user', 'order': 'TO', 'mode': 'full'}]\n" + ] + } + ], + "source": [ + "# dataset splitting\n", + "train_data, valid_data, test_data = data_preparation(config, dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9645ae09", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:28:37.715732Z", + "iopub.status.busy": "2022-03-20T02:28:37.714892Z", + "iopub.status.idle": "2022-03-20T02:50:21.046599Z", + "shell.execute_reply": "2022-03-20T02:50:21.045317Z", + "shell.execute_reply.started": "2022-03-19T09:12:13.351077Z" + }, + "papermill": { + "duration": 1303.559679, + "end_time": "2022-03-20T02:50:21.046804", + "exception": false, + "start_time": "2022-03-20T02:28:37.487125", + "status": "completed" + }, + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GRU4Rec(\n", + " (item_embedding): Embedding(10962, 64, padding_idx=0)\n", + " (emb_dropout): Dropout(p=0.3, inplace=False)\n", + " (gru_layers): GRU(64, 128, bias=False, batch_first=True)\n", + " (dense): Linear(in_features=128, out_features=64, bias=True)\n", + " (loss_fct): CrossEntropyLoss()\n", + ")\n", + "Trainable parameters: 783552\n", + "epoch 0 training [time: 29.21s, train loss: 7608.2684]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 1 training [time: 26.44s, train loss: 7102.8474]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 2 training [time: 25.82s, train loss: 6864.3110]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 3 training [time: 25.81s, train loss: 6658.3106]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 4 training [time: 25.81s, train loss: 6516.7922]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 5 training [time: 25.82s, train loss: 6418.4797]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 6 training [time: 25.55s, train loss: 6338.6159]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 7 training [time: 26.08s, train loss: 6273.0269]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 8 training [time: 25.75s, train loss: 6216.3229]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 9 training [time: 25.85s, train loss: 6168.1504]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 10 training [time: 25.45s, train loss: 6125.8122]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 11 training [time: 25.87s, train loss: 6088.1390]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 12 training [time: 25.71s, train loss: 6056.3469]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 13 training [time: 25.75s, train loss: 6028.3020]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 14 training [time: 25.52s, train loss: 6002.9929]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 15 training [time: 25.74s, train loss: 5981.4722]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 16 training [time: 25.50s, train loss: 5962.1978]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 17 training [time: 25.69s, train loss: 5946.1110]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 18 training [time: 25.66s, train loss: 5931.8931]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 19 training [time: 25.61s, train loss: 5918.5745]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 20 training [time: 25.33s, train loss: 5907.3241]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 21 training [time: 25.86s, train loss: 5897.1385]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 22 training [time: 25.55s, train loss: 5887.6680]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 23 training [time: 25.80s, train loss: 5879.6025]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 24 training [time: 25.52s, train loss: 5871.1030]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 25 training [time: 25.84s, train loss: 5864.6343]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 26 training [time: 25.54s, train loss: 5858.6181]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 27 training [time: 25.94s, train loss: 5851.7828]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 28 training [time: 25.66s, train loss: 5846.2326]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 29 training [time: 25.61s, train loss: 5840.5342]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 30 training [time: 25.72s, train loss: 5836.3762]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 31 training [time: 25.81s, train loss: 5831.4665]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 32 training [time: 25.65s, train loss: 5826.3515]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 33 training [time: 26.00s, train loss: 5822.0054]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 34 training [time: 25.27s, train loss: 5818.5471]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 35 training [time: 25.82s, train loss: 5814.4312]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 36 training [time: 25.81s, train loss: 5810.0819]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 37 training [time: 25.82s, train loss: 5807.7266]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 38 training [time: 25.45s, train loss: 5803.9869]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 39 training [time: 26.11s, train loss: 5801.2985]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 40 training [time: 25.86s, train loss: 5798.3783]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 41 training [time: 26.02s, train loss: 5795.3381]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 42 training [time: 26.13s, train loss: 5792.3586]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 43 training [time: 25.52s, train loss: 5789.8296]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 44 training [time: 25.86s, train loss: 5787.6979]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 45 training [time: 26.05s, train loss: 5785.8937]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 46 training [time: 25.78s, train loss: 5782.2815]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 47 training [time: 25.79s, train loss: 5781.2301]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 48 training [time: 25.52s, train loss: 5777.6824]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n", + "epoch 49 training [time: 25.87s, train loss: 5776.9169]\n", + "Saving current: saved/GRU4Rec-Mar-20-2022_02-28-47.pth\n" + ] + } + ], + "source": [ + "# model loading and initialization\n", + "model = GRU4Rec(config, train_data.dataset).to(config['device'])\n", + "logger.info(model)\n", + "\n", + "# trainer loading and initialization\n", + "trainer = Trainer(config, model)\n", + "\n", + "# model training\n", + "best_valid_score, best_valid_result = trainer.fit(train_data)" + ] + }, + { + "cell_type": "markdown", + "id": "da91afeb", + "metadata": { + "papermill": { + "duration": 0.386664, + "end_time": "2022-03-20T02:50:21.838866", + "exception": false, + "start_time": "2022-03-20T02:50:21.452202", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 4. Create recommendation result from trained model\n", + "\n", + "I note document here for any one want to customize it: https://recbole.io/docs/user_guide/usage/case_study.html" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3303aa1c", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:50:22.633719Z", + "iopub.status.busy": "2022-03-20T02:50:22.632377Z", + "iopub.status.idle": "2022-03-20T02:50:22.639114Z", + "shell.execute_reply": "2022-03-20T02:50:22.639669Z", + "shell.execute_reply.started": "2022-03-19T09:13:42.897843Z" + }, + "papermill": { + "duration": 0.407741, + "end_time": "2022-03-20T02:50:22.639886", + "exception": false, + "start_time": "2022-03-20T02:50:22.232145", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "external_user_ids = dataset.id2token(\n", + " dataset.uid_field, list(range(dataset.user_num)))[1:]#fist element in array is 'PAD'(default of Recbole) ->remove it " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "77a87d57", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:50:23.765193Z", + "iopub.status.busy": "2022-03-20T02:50:23.764128Z", + "iopub.status.idle": "2022-03-20T02:50:23.775557Z", + "shell.execute_reply": "2022-03-20T02:50:23.776880Z", + "shell.execute_reply.started": "2022-03-19T09:17:48.660391Z" + }, + "papermill": { + "duration": 0.703287, + "end_time": "2022-03-20T02:50:23.777164", + "exception": false, + "start_time": "2022-03-20T02:50:23.073877", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "from recbole.data.interaction import Interaction\n", + "\n", + "def add_last_item(old_interaction, last_item_id, max_len=50):\n", + " new_seq_items = old_interaction['item_id_list'][-1]\n", + " if old_interaction['item_length'][-1].item() < max_len:\n", + " new_seq_items[old_interaction['item_length'][-1].item()] = last_item_id\n", + " else:\n", + " new_seq_items = torch.roll(new_seq_items, -1)\n", + " new_seq_items[-1] = last_item_id\n", + " return new_seq_items.view(1, len(new_seq_items))\n", + "\n", + "def predict_for_all_item(external_user_id, dataset, model):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " uid_series = dataset.token2id(dataset.uid_field, [external_user_id])\n", + " index = np.isin(dataset[dataset.uid_field].numpy(), uid_series)\n", + " input_interaction = dataset[index]\n", + " test = {\n", + " 'item_id_list': add_last_item(input_interaction, \n", + " input_interaction['item_id'][-1].item(), model.max_seq_length),\n", + " 'item_length': torch.tensor(\n", + " [input_interaction['item_length'][-1].item() + 1\n", + " if input_interaction['item_length'][-1].item() < model.max_seq_length else model.max_seq_length])\n", + " }\n", + " new_inter = Interaction(test)\n", + " new_inter = new_inter.to(config['device'])\n", + " new_scores = model.full_sort_predict(new_inter)\n", + " new_scores = new_scores.view(-1, test_data.dataset.item_num)\n", + " new_scores[:, 0] = -np.inf # set scores of [pad] to -inf\n", + " return torch.topk(new_scores, 10)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "dacd5516", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:50:24.598660Z", + "iopub.status.busy": "2022-03-20T02:50:24.597549Z", + "iopub.status.idle": "2022-03-20T02:50:24.753648Z", + "shell.execute_reply": "2022-03-20T02:50:24.752884Z", + "shell.execute_reply.started": "2022-03-19T09:18:08.614523Z" + }, + "papermill": { + "duration": 0.545835, + "end_time": "2022-03-20T02:50:24.753849", + "exception": false, + "start_time": "2022-03-20T02:50:24.208014", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.return_types.topk(\n", + "values=tensor([[7.9712, 7.7557, 6.3152, 6.0824, 6.0296, 5.8736, 5.8550, 5.8297, 5.8106,\n", + " 5.7406]], device='cuda:0'),\n", + "indices=tensor([[6713, 6663, 8766, 496, 8749, 2763, 3097, 2117, 643, 2838]],\n", + " device='cuda:0'))" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict_for_all_item('0109ad0b5a76924a1b58be677409bb601cc8bead9a87b8ce5b08a4a1f5bc71ef', \n", + " dataset, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "bb8ef264", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T02:50:25.539828Z", + "iopub.status.busy": "2022-03-20T02:50:25.538457Z", + "iopub.status.idle": "2022-03-20T03:16:02.324679Z", + "shell.execute_reply": "2022-03-20T03:16:02.323774Z", + "shell.execute_reply.started": "2022-03-19T09:18:22.091273Z" + }, + "papermill": { + "duration": 1537.179191, + "end_time": "2022-03-20T03:16:02.324943", + "exception": false, + "start_time": "2022-03-20T02:50:25.145752", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "38915\n" + ] + } + ], + "source": [ + "topk_items = []\n", + "for external_user_id in external_user_ids:\n", + " _, topk_iid_list = predict_for_all_item(external_user_id, dataset, model)\n", + " last_topk_iid_list = topk_iid_list[-1]\n", + " external_item_list = dataset.id2token(dataset.iid_field, last_topk_iid_list.cpu()).tolist()\n", + " topk_items.append(external_item_list)\n", + "print(len(topk_items))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "11a06a38", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:03.112776Z", + "iopub.status.busy": "2022-03-20T03:16:03.111650Z", + "iopub.status.idle": "2022-03-20T03:16:03.147289Z", + "shell.execute_reply": "2022-03-20T03:16:03.147869Z", + "shell.execute_reply.started": "2022-03-19T09:43:08.156694Z" + }, + "papermill": { + "duration": 0.436152, + "end_time": "2022-03-20T03:16:03.148114", + "exception": false, + "start_time": "2022-03-20T03:16:02.711962", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction
00010e8eb18f131e724d6997909af0808adbba057529edb...0372860001 0706016003 0706016001 0610776002 08...
10064cd1ee810d4caabd1182a8f177479b82b18961bd76b...0894956001 0907527001 0905957001 0769748014 07...
200ce4f170d9fe36d0aacca94addfc3b07f70f81dc7bde3...0881244001 0867966009 0889652001 0750422039 07...
300d7ebd46f6a6d53630d41386b6ef6a505cdc4c80011ff...0918522001 0915526001 0751592001 0924243001 09...
400eebac2c2e37626461e74e8395711964c4e01a7afa643...0866731001 0875350003 0915526001 0933891001 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 0010e8eb18f131e724d6997909af0808adbba057529edb... \n", + "1 0064cd1ee810d4caabd1182a8f177479b82b18961bd76b... \n", + "2 00ce4f170d9fe36d0aacca94addfc3b07f70f81dc7bde3... \n", + "3 00d7ebd46f6a6d53630d41386b6ef6a505cdc4c80011ff... \n", + "4 00eebac2c2e37626461e74e8395711964c4e01a7afa643... \n", + "\n", + " prediction \n", + "0 0372860001 0706016003 0706016001 0610776002 08... \n", + "1 0894956001 0907527001 0905957001 0769748014 07... \n", + "2 0881244001 0867966009 0889652001 0750422039 07... \n", + "3 0918522001 0915526001 0751592001 0924243001 09... \n", + "4 0866731001 0875350003 0915526001 0933891001 08... " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "external_item_str = [' '.join(x) for x in topk_items]\n", + "result = pd.DataFrame(external_user_ids, columns=['customer_id'])\n", + "result['prediction'] = external_item_str\n", + "result.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "89ea3162", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:04.533514Z", + "iopub.status.busy": "2022-03-20T03:16:04.531232Z", + "iopub.status.idle": "2022-03-20T03:16:04.536176Z", + "shell.execute_reply": "2022-03-20T03:16:04.535590Z", + "shell.execute_reply.started": "2022-03-19T09:45:15.244813Z" + }, + "papermill": { + "duration": 0.668816, + "end_time": "2022-03-20T03:16:04.536353", + "exception": false, + "start_time": "2022-03-20T03:16:03.867537", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del external_item_str\n", + "del topk_items\n", + "del external_user_ids\n", + "del train_data\n", + "del valid_data\n", + "del test_data\n", + "del model\n", + "del Trainer\n", + "del logger\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "5d6fe3a8", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:05.551146Z", + "iopub.status.busy": "2022-03-20T03:16:05.370774Z", + "iopub.status.idle": "2022-03-20T03:16:05.554678Z", + "shell.execute_reply": "2022-03-20T03:16:05.555311Z", + "shell.execute_reply.started": "2022-03-19T09:46:46.555935Z" + }, + "papermill": { + "duration": 0.635722, + "end_time": "2022-03-20T03:16:05.555501", + "exception": false, + "start_time": "2022-03-20T03:16:04.919779", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "21" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del dataset\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "8906703f", + "metadata": { + "papermill": { + "duration": 0.392395, + "end_time": "2022-03-20T03:16:06.331419", + "exception": false, + "start_time": "2022-03-20T03:16:05.939024", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# 5. Combine result from most bought items and GRU model" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "27c4d836", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:07.110312Z", + "iopub.status.busy": "2022-03-20T03:16:07.109301Z", + "iopub.status.idle": "2022-03-20T03:16:10.656448Z", + "shell.execute_reply": "2022-03-20T03:16:10.655419Z", + "shell.execute_reply.started": "2022-03-19T09:46:59.90542Z" + }, + "papermill": { + "duration": 3.936122, + "end_time": "2022-03-20T03:16:10.656605", + "exception": false, + "start_time": "2022-03-20T03:16:06.720483", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(1371980, 2)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = pd.read_csv('submission.csv')\n", + "submit_df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d12785d1", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:11.465716Z", + "iopub.status.busy": "2022-03-20T03:16:11.464599Z", + "iopub.status.idle": "2022-03-20T03:16:11.468743Z", + "shell.execute_reply": "2022-03-20T03:16:11.469366Z", + "shell.execute_reply.started": "2022-03-19T09:47:03.369678Z" + }, + "papermill": { + "duration": 0.422846, + "end_time": "2022-03-20T03:16:11.469547", + "exception": false, + "start_time": "2022-03-20T03:16:11.046701", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "102363f9", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:12.249718Z", + "iopub.status.busy": "2022-03-20T03:16:12.248055Z", + "iopub.status.idle": "2022-03-20T03:16:13.361191Z", + "shell.execute_reply": "2022-03-20T03:16:13.360487Z", + "shell.execute_reply.started": "2022-03-19T09:47:03.384134Z" + }, + "papermill": { + "duration": 1.505033, + "end_time": "2022-03-20T03:16:13.361342", + "exception": false, + "start_time": "2022-03-20T03:16:11.856309", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction_xprediction_y
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...NaN
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...NaN
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...NaN
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...NaN
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...NaN
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction_x prediction_y \n", + "0 0568601043 0568601006 0656719005 0745232001 09... NaN \n", + "1 0826211002 0800436010 0924243001 0739590027 07... NaN \n", + "2 0794321007 0852643001 0852643003 0858883002 09... NaN \n", + "3 0448509014 0573085028 0924243001 0751471001 07... NaN \n", + "4 0730683050 0791587015 0924243001 0896152002 08... NaN " + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = pd.merge(submit_df, result, on='customer_id', how='outer')\n", + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "ea585e0c", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:14.288474Z", + "iopub.status.busy": "2022-03-20T03:16:14.287176Z", + "iopub.status.idle": "2022-03-20T03:16:44.750560Z", + "shell.execute_reply": "2022-03-20T03:16:44.751188Z", + "shell.execute_reply.started": "2022-03-19T09:47:04.438889Z" + }, + "papermill": { + "duration": 30.99182, + "end_time": "2022-03-20T03:16:44.751372", + "exception": false, + "start_time": "2022-03-20T03:16:13.759552", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction_xprediction_yprediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...-10568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...-10826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...-10794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...-10448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...-10730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction_x prediction_y \\\n", + "0 0568601043 0568601006 0656719005 0745232001 09... -1 \n", + "1 0826211002 0800436010 0924243001 0739590027 07... -1 \n", + "2 0794321007 0852643001 0852643003 0858883002 09... -1 \n", + "3 0448509014 0573085028 0924243001 0751471001 07... -1 \n", + "4 0730683050 0791587015 0924243001 0896152002 08... -1 \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = submit_df.fillna(-1)\n", + "submit_df['prediction'] = submit_df.apply(\n", + " lambda x: x['prediction_y'] if x['prediction_y'] != -1 else x['prediction_x'], axis=1)\n", + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "754283af", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:45.707218Z", + "iopub.status.busy": "2022-03-20T03:16:45.705648Z", + "iopub.status.idle": "2022-03-20T03:16:45.923952Z", + "shell.execute_reply": "2022-03-20T03:16:45.924550Z", + "shell.execute_reply.started": "2022-03-19T09:47:34.25912Z" + }, + "papermill": { + "duration": 0.781266, + "end_time": "2022-03-20T03:16:45.924737", + "exception": false, + "start_time": "2022-03-20T03:16:45.143471", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction_xprediction_yprediction
1300009d946eec3ea54add5ba56d5210ea898def4b46c685...0891899004 0562245099 0797892001 0516859008 07...0568597007 0568601007 0831450002 0881244001 05...0568597007 0568601007 0831450002 0881244001 05...
380001d44dbe7f6c4b35200abdb052c77a87596fe1bdcc37...0734592001 0888024005 0572998013 0909869004 08...0891591001 0933706001 0919499007 0911214001 09...0891591001 0933706001 0919499007 0911214001 09...
1690006d3ff0caf0cb4d4e0615ee5cb7d268622364d483335...0930829001 0915529001 0870525005 0751471041 05...0884319006 0928088001 0915529001 0832307007 09...0884319006 0928088001 0915529001 0832307007 09...
17500075ef36696a7b4ed8c83e22a4bf7ea7c90ee110991ec...0860285001 0863595006 0824526004 0751471022 08...0824526004 0874819002 0893059005 0893059004 08...0824526004 0874819002 0893059005 0893059004 08...
19500080403a669b3b89d1bef1ec73ea466d95e39698d6dde...0825771007 0784053005 0924243001 0914886003 08...0914319001 0876147001 0868038003 0914319002 08...0914319001 0876147001 0868038003 0914319002 08...
...............
1371778fff624f63f0279200646a4f8bf27e5150096212d50fdd0...0399256001 0873045001 0842755001 0885870002 08...0842755001 0865917002 0869397001 0811900002 08...0842755001 0865917002 0869397001 0811900002 08...
1371787fff673307d4cdbf688e4a0bcfe7f671036033dbe7eba01...0865086004 0881916001 0711053003 0791587015 07...0865086004 0754238024 0894956001 0852584001 09...0865086004 0754238024 0894956001 0852584001 09...
1371876fffabaebcc10efa0e613b58de37901e04fa25a2f90a0a8...0652924004 0894400002 0924243001 0573937001 08...0756904015 0854777001 0739533002 0844874012 05...0756904015 0854777001 0739533002 0844874012 05...
1371879fffae8eb3a282d8c43c77dd2ca0621703b71e90904dfde...0865624003 0396135007 0797892001 0817472007 07...0729928025 0817472004 0729928001 0865624003 08...0729928025 0817472004 0729928001 0865624003 08...
1371960fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20...0898573003 0748269009 0905365002 0881919001 08...0898573003 0762796013 0893141002 0915529005 08...0898573003 0762796013 0893141002 0915529005 08...
\n", + "

38915 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "13 00009d946eec3ea54add5ba56d5210ea898def4b46c685... \n", + "38 0001d44dbe7f6c4b35200abdb052c77a87596fe1bdcc37... \n", + "169 0006d3ff0caf0cb4d4e0615ee5cb7d268622364d483335... \n", + "175 00075ef36696a7b4ed8c83e22a4bf7ea7c90ee110991ec... \n", + "195 00080403a669b3b89d1bef1ec73ea466d95e39698d6dde... \n", + "... ... \n", + "1371778 fff624f63f0279200646a4f8bf27e5150096212d50fdd0... \n", + "1371787 fff673307d4cdbf688e4a0bcfe7f671036033dbe7eba01... \n", + "1371876 fffabaebcc10efa0e613b58de37901e04fa25a2f90a0a8... \n", + "1371879 fffae8eb3a282d8c43c77dd2ca0621703b71e90904dfde... \n", + "1371960 fffef3b6b73545df065b521e19f64bf6fe93bfd450ab20... \n", + "\n", + " prediction_x \\\n", + "13 0891899004 0562245099 0797892001 0516859008 07... \n", + "38 0734592001 0888024005 0572998013 0909869004 08... \n", + "169 0930829001 0915529001 0870525005 0751471041 05... \n", + "175 0860285001 0863595006 0824526004 0751471022 08... \n", + "195 0825771007 0784053005 0924243001 0914886003 08... \n", + "... ... \n", + "1371778 0399256001 0873045001 0842755001 0885870002 08... \n", + "1371787 0865086004 0881916001 0711053003 0791587015 07... \n", + "1371876 0652924004 0894400002 0924243001 0573937001 08... \n", + "1371879 0865624003 0396135007 0797892001 0817472007 07... \n", + "1371960 0898573003 0748269009 0905365002 0881919001 08... \n", + "\n", + " prediction_y \\\n", + "13 0568597007 0568601007 0831450002 0881244001 05... \n", + "38 0891591001 0933706001 0919499007 0911214001 09... \n", + "169 0884319006 0928088001 0915529001 0832307007 09... \n", + "175 0824526004 0874819002 0893059005 0893059004 08... \n", + "195 0914319001 0876147001 0868038003 0914319002 08... \n", + "... ... \n", + "1371778 0842755001 0865917002 0869397001 0811900002 08... \n", + "1371787 0865086004 0754238024 0894956001 0852584001 09... \n", + "1371876 0756904015 0854777001 0739533002 0844874012 05... \n", + "1371879 0729928025 0817472004 0729928001 0865624003 08... \n", + "1371960 0898573003 0762796013 0893141002 0915529005 08... \n", + "\n", + " prediction \n", + "13 0568597007 0568601007 0831450002 0881244001 05... \n", + "38 0891591001 0933706001 0919499007 0911214001 09... \n", + "169 0884319006 0928088001 0915529001 0832307007 09... \n", + "175 0824526004 0874819002 0893059005 0893059004 08... \n", + "195 0914319001 0876147001 0868038003 0914319002 08... \n", + "... ... \n", + "1371778 0842755001 0865917002 0869397001 0811900002 08... \n", + "1371787 0865086004 0754238024 0894956001 0852584001 09... \n", + "1371876 0756904015 0854777001 0739533002 0844874012 05... \n", + "1371879 0729928025 0817472004 0729928001 0865624003 08... \n", + "1371960 0898573003 0762796013 0893141002 0915529005 08... \n", + "\n", + "[38915 rows x 4 columns]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df[submit_df['prediction_y'] != -1]" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "223a1931", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:46.723462Z", + "iopub.status.busy": "2022-03-20T03:16:46.722014Z", + "iopub.status.idle": "2022-03-20T03:16:46.839745Z", + "shell.execute_reply": "2022-03-20T03:16:46.839187Z", + "shell.execute_reply.started": "2022-03-19T09:47:34.685876Z" + }, + "papermill": { + "duration": 0.520431, + "end_time": "2022-03-20T03:16:46.839926", + "exception": false, + "start_time": "2022-03-20T03:16:46.319495", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
customer_idprediction
000000dbacae5abe5e23885899a1fa44253a17956c6d1c3...0568601043 0568601006 0656719005 0745232001 09...
10000423b00ade91418cceaf3b26c6af3dd342b51fd051e...0826211002 0800436010 0924243001 0739590027 07...
2000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...0794321007 0852643001 0852643003 0858883002 09...
300005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2...0448509014 0573085028 0924243001 0751471001 07...
400006413d8573cd20ed7128e53b7b13819fe5cfc2d801f...0730683050 0791587015 0924243001 0896152002 08...
\n", + "
" + ], + "text/plain": [ + " customer_id \\\n", + "0 00000dbacae5abe5e23885899a1fa44253a17956c6d1c3... \n", + "1 0000423b00ade91418cceaf3b26c6af3dd342b51fd051e... \n", + "2 000058a12d5b43e67d225668fa1f8d618c13dc232df0ca... \n", + "3 00005ca1c9ed5f5146b52ac8639a40ca9d57aeff4d1bd2... \n", + "4 00006413d8573cd20ed7128e53b7b13819fe5cfc2d801f... \n", + "\n", + " prediction \n", + "0 0568601043 0568601006 0656719005 0745232001 09... \n", + "1 0826211002 0800436010 0924243001 0739590027 07... \n", + "2 0794321007 0852643001 0852643003 0858883002 09... \n", + "3 0448509014 0573085028 0924243001 0751471001 07... \n", + "4 0730683050 0791587015 0924243001 0896152002 08... " + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "submit_df = submit_df.drop(columns=['prediction_y', 'prediction_x'])\n", + "submit_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "7a361474", + "metadata": { + "execution": { + "iopub.execute_input": "2022-03-20T03:16:47.973860Z", + "iopub.status.busy": "2022-03-20T03:16:47.972808Z", + "iopub.status.idle": "2022-03-20T03:16:59.835508Z", + "shell.execute_reply": "2022-03-20T03:16:59.836813Z", + "shell.execute_reply.started": "2022-03-19T09:47:34.809425Z" + }, + "papermill": { + "duration": 12.508319, + "end_time": "2022-03-20T03:16:59.837107", + "exception": false, + "start_time": "2022-03-20T03:16:47.328788", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "submit_df.to_csv('submission.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0fad61d", + "metadata": { + "papermill": { + "duration": 0.414141, + "end_time": "2022-03-20T03:17:00.840332", + "exception": false, + "start_time": "2022-03-20T03:17:00.426191", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + }, + "papermill": { + "default_parameters": {}, + "duration": 3480.556455, + "end_time": "2022-03-20T03:17:04.204601", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2022-03-20T02:19:03.648146", + "version": "2.3.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a615d4d081bfd7b983ba545b5c41b0f030904d8b Mon Sep 17 00:00:00 2001 From: Sherry-XLL Date: Mon, 4 Apr 2022 04:56:27 +0000 Subject: [PATCH 2/5] FEA: add source link and author of ipynb --- .../lstm-model-with-item-infor-fix-missing-last-item.ipynb | 3 +++ run_example/recbole-using-all-items-for-prediction.ipynb | 3 +++ run_example/sequential-model-fixed-missing-last-item.ipynb | 3 +++ 3 files changed, 9 insertions(+) diff --git a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb index 84cfd2125..039d18cd9 100644 --- a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb +++ b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb @@ -15,6 +15,9 @@ }, "source": [ "# 0.Overview\n", + "This is a notebook tutorial about **lstm-model-with-item-infor-fix-missing-last-item** created by [astrung](https://github.com/astrung), and the original link is [here](https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item).\n", + "\n", + "- - -\n", "\n", "**Edit:**\n", "\n", diff --git a/run_example/recbole-using-all-items-for-prediction.ipynb b/run_example/recbole-using-all-items-for-prediction.ipynb index db9ce3d5f..fc454e3be 100644 --- a/run_example/recbole-using-all-items-for-prediction.ipynb +++ b/run_example/recbole-using-all-items-for-prediction.ipynb @@ -14,6 +14,9 @@ "tags": [] }, "source": [ + "This is a notebook tutorial about **recbole-using-all-items-for-prediction** created by [astrung](https://github.com/astrung), and the original link is [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction/notebook).\n", + "\n", + "- - -\n", "**Edit**: \n", "I have create new notebooks for applying our customize function for using all items as input for recommendation:\n", "* Using only interactions: https://www.kaggle.com/astrung/sequential-model-fixed-missing-last-item\n", diff --git a/run_example/sequential-model-fixed-missing-last-item.ipynb b/run_example/sequential-model-fixed-missing-last-item.ipynb index 4cb98d3c8..f9a21de21 100644 --- a/run_example/sequential-model-fixed-missing-last-item.ipynb +++ b/run_example/sequential-model-fixed-missing-last-item.ipynb @@ -15,6 +15,9 @@ }, "source": [ "# 0.Overview\n", + "This is a notebook tutorial about **sequential-model-fixed-missing-last-item** created by [astrung](https://github.com/astrung), and the original link is [here](https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item).\n", + "\n", + "- - -\n", "**Edit**:\n", "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", "* So i created a notebook [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction) for address there questions in detail, and this notebook is an improved of my [previous notebook](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial), applying our new function (using all item as input features without `full_sort_topk`) for this competition.\n", From 3e566315301f5bd2cc6a8ce83c0db57e94e1fc6e Mon Sep 17 00:00:00 2001 From: Sherry-XLL Date: Mon, 4 Apr 2022 05:04:00 +0000 Subject: [PATCH 3/5] FIX: update author and link of ipynb --- .../lstm-model-with-item-infor-fix-missing-last-item.ipynb | 7 +++---- run_example/recbole-using-all-items-for-prediction.ipynb | 6 +++--- run_example/sequential-model-fixed-missing-last-item.ipynb | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb index 039d18cd9..7dfe65269 100644 --- a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb +++ b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb @@ -15,10 +15,9 @@ }, "source": [ "# 0.Overview\n", - "This is a notebook tutorial about **lstm-model-with-item-infor-fix-missing-last-item** created by [astrung](https://github.com/astrung), and the original link is [here](https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item).\n", - "\n", - "- - -\n", - "\n", + "**Tutorial**: lstm-model-with-item-infor-fix-missing-last-item\n", + "**Author**: [astrung](https://github.com/astrung)\n", + "**Original link**: [notebook](https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item)\n", "**Edit:**\n", "\n", "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", diff --git a/run_example/recbole-using-all-items-for-prediction.ipynb b/run_example/recbole-using-all-items-for-prediction.ipynb index fc454e3be..7e8705813 100644 --- a/run_example/recbole-using-all-items-for-prediction.ipynb +++ b/run_example/recbole-using-all-items-for-prediction.ipynb @@ -14,9 +14,9 @@ "tags": [] }, "source": [ - "This is a notebook tutorial about **recbole-using-all-items-for-prediction** created by [astrung](https://github.com/astrung), and the original link is [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction/notebook).\n", - "\n", - "- - -\n", + "**Tutorial**: recbole-using-all-items-for-prediction\n", + "**Author**: [astrung](https://github.com/astrung)\n", + "**Original link**: [notebook](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction/notebook)\n", "**Edit**: \n", "I have create new notebooks for applying our customize function for using all items as input for recommendation:\n", "* Using only interactions: https://www.kaggle.com/astrung/sequential-model-fixed-missing-last-item\n", diff --git a/run_example/sequential-model-fixed-missing-last-item.ipynb b/run_example/sequential-model-fixed-missing-last-item.ipynb index f9a21de21..ecf46c1fa 100644 --- a/run_example/sequential-model-fixed-missing-last-item.ipynb +++ b/run_example/sequential-model-fixed-missing-last-item.ipynb @@ -15,9 +15,9 @@ }, "source": [ "# 0.Overview\n", - "This is a notebook tutorial about **sequential-model-fixed-missing-last-item** created by [astrung](https://github.com/astrung), and the original link is [here](https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item).\n", - "\n", - "- - -\n", + "**Tutorial**: sequential-model-fixed-missing-last-item\n", + "**Author**: [astrung](https://github.com/astrung)\n", + "**Original link**: [notebook](https://www.kaggle.com/code/astrung/sequential-model-fixed-missing-last-item)\n", "**Edit**:\n", "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", "* So i created a notebook [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction) for address there questions in detail, and this notebook is an improved of my [previous notebook](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial), applying our new function (using all item as input features without `full_sort_topk`) for this competition.\n", From 5a81a934719aac90d532893afa0b8afc4c86ab2c Mon Sep 17 00:00:00 2001 From: Sherry-XLL Date: Mon, 4 Apr 2022 05:05:39 +0000 Subject: [PATCH 4/5] FIX: update author and link of ipynb --- .../lstm-model-with-item-infor-fix-missing-last-item.ipynb | 3 +++ run_example/recbole-using-all-items-for-prediction.ipynb | 3 +++ run_example/sequential-model-fixed-missing-last-item.ipynb | 3 +++ 3 files changed, 9 insertions(+) diff --git a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb index 7dfe65269..21bbeced2 100644 --- a/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb +++ b/run_example/lstm-model-with-item-infor-fix-missing-last-item.ipynb @@ -16,8 +16,11 @@ "source": [ "# 0.Overview\n", "**Tutorial**: lstm-model-with-item-infor-fix-missing-last-item\n", + "\n", "**Author**: [astrung](https://github.com/astrung)\n", + "\n", "**Original link**: [notebook](https://www.kaggle.com/code/astrung/lstm-model-with-item-infor-fix-missing-last-item)\n", + "\n", "**Edit:**\n", "\n", "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", diff --git a/run_example/recbole-using-all-items-for-prediction.ipynb b/run_example/recbole-using-all-items-for-prediction.ipynb index 7e8705813..dc07271e2 100644 --- a/run_example/recbole-using-all-items-for-prediction.ipynb +++ b/run_example/recbole-using-all-items-for-prediction.ipynb @@ -15,8 +15,11 @@ }, "source": [ "**Tutorial**: recbole-using-all-items-for-prediction\n", + "\n", "**Author**: [astrung](https://github.com/astrung)\n", + "\n", "**Original link**: [notebook](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction/notebook)\n", + "\n", "**Edit**: \n", "I have create new notebooks for applying our customize function for using all items as input for recommendation:\n", "* Using only interactions: https://www.kaggle.com/astrung/sequential-model-fixed-missing-last-item\n", diff --git a/run_example/sequential-model-fixed-missing-last-item.ipynb b/run_example/sequential-model-fixed-missing-last-item.ipynb index ecf46c1fa..ec9798f0a 100644 --- a/run_example/sequential-model-fixed-missing-last-item.ipynb +++ b/run_example/sequential-model-fixed-missing-last-item.ipynb @@ -16,8 +16,11 @@ "source": [ "# 0.Overview\n", "**Tutorial**: sequential-model-fixed-missing-last-item\n", + "\n", "**Author**: [astrung](https://github.com/astrung)\n", + "\n", "**Original link**: [notebook](https://www.kaggle.com/code/astrung/sequential-model-fixed-missing-last-item)\n", + "\n", "**Edit**:\n", "* In my previous notebooks([here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial) and [here](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial)), we have used test_data with `full_sort_topk`,but due to the limit of full_sort_topk we have missed last item for submited recommendation. Someone asked me about how can use all items as input features for recommendation in this [comment](https://www.kaggle.com/code/astrung/recbole-lstm-sequential-for-recomendation-tutorial/comments#1723707). \n", "* So i created a notebook [here](https://www.kaggle.com/code/astrung/recbole-using-all-items-for-prediction) for address there questions in detail, and this notebook is an improved of my [previous notebook](https://www.kaggle.com/code/astrung/lstm-sequential-modelwith-item-features-tutorial), applying our new function (using all item as input features without `full_sort_topk`) for this competition.\n", From 44232c160ca3c18a6f597099242bcfa885e156e8 Mon Sep 17 00:00:00 2001 From: Sherry-XLL Date: Mon, 4 Apr 2022 05:09:50 +0000 Subject: [PATCH 5/5] FEA: Add overview of recbole-using-all-items-for-prediction --- run_example/recbole-using-all-items-for-prediction.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/run_example/recbole-using-all-items-for-prediction.ipynb b/run_example/recbole-using-all-items-for-prediction.ipynb index dc07271e2..33eca3bba 100644 --- a/run_example/recbole-using-all-items-for-prediction.ipynb +++ b/run_example/recbole-using-all-items-for-prediction.ipynb @@ -14,6 +14,7 @@ "tags": [] }, "source": [ + "# Overview\n", "**Tutorial**: recbole-using-all-items-for-prediction\n", "\n", "**Author**: [astrung](https://github.com/astrung)\n",