From d0b0557acf5eda3deb2b98aaaf35fc59e56110d0 Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Tue, 7 Sep 2021 00:14:26 -0700 Subject: [PATCH] add 07-backprop.ipynb --- notebooks/04-numpy-lr.ipynb | 9 +- notebooks/07-backprop.ipynb | 581 ++++++++++++++++++++++++++++++++++++ 2 files changed, 589 insertions(+), 1 deletion(-) create mode 100644 notebooks/07-backprop.ipynb diff --git a/notebooks/04-numpy-lr.ipynb b/notebooks/04-numpy-lr.ipynb index 6664e11..faf4974 100644 --- a/notebooks/04-numpy-lr.ipynb +++ b/notebooks/04-numpy-lr.ipynb @@ -76,6 +76,13 @@ "## 학습데이터 로드" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "데이터는 [Dacon 단국대 소/중 데이터 분석 AI 경진대회 웹사이트](https://www.dacon.io/competitions/official/235638/data/)에서 다운로드 받아 `../input` 폴더에 저장." + ] + }, { "cell_type": "code", "execution_count": 4, @@ -963,7 +970,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.5" }, "toc": { "base_numbering": 1, diff --git a/notebooks/07-backprop.ipynb b/notebooks/07-backprop.ipynb new file mode 100644 index 0000000..1fb0856 --- /dev/null +++ b/notebooks/07-backprop.ipynb @@ -0,0 +1,581 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 역전파 데모" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 라이브러리 import 및 설정" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:09.537987Z", + "start_time": "2021-09-07T07:08:09.145618Z" + } + }, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:10.292310Z", + "start_time": "2021-09-07T07:08:09.539714Z" + } + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n", + "from matplotlib import rcParams\n", + "import
numpy as np\n", + "from pathlib import Path\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from tqdm.notebook import tqdm\n", + "import warnings" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:10.312044Z", + "start_time": "2021-09-07T07:08:10.294207Z" + } + }, + "outputs": [], + "source": [ + "rcParams['figure.figsize'] = (16, 8)\n", + "plt.style.use('fivethirtyeight')\n", + "pd.set_option('max_columns', 100)\n", + "pd.set_option(\"display.precision\", 4)\n", + "warnings.simplefilter('ignore')\n", + "np.set_printoptions(4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 학습데이터 로드" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "데이터는 [Dacon 단국대 소/중 데이터 분석 AI 경진대회 웹사이트](https://www.dacon.io/competitions/official/235638/data/)에서 다운로드 받아 `../input` 폴더에 저장." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:10.331216Z", + "start_time": "2021-09-07T07:08:10.314059Z" + } + }, + "outputs": [], + "source": [ + "data_dir = Path('../input/')\n", + "trn_file = data_dir / 'train.csv'\n", + "seed = 42\n", + "np.random.seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:11.140422Z", + "start_time": "2021-09-07T07:08:10.332477Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(320000, 19)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ugrizredshiftdered_udered_gdered_rdered_idered_znObservenDetectairmass_uairmass_gairmass_rairmass_iairmass_zclass
id
023.264020.336819.009517.672416.9396-8.1086e-0523.124320.257818.955117.632116.908918181.18981.19071.18901.18941.19020
115.052114.062013.452413.268413.16894.5061e-0314.966414.004513.411413.236313.1347111.25331.25781.24881.25101.25551
216.786415.825415.536315.393515.35004.7198e-0416.607615.686615.440015.321715.2961221.02251.02411.02101.02171.02330
325.660621.188720.221219.894919.63465.8143e-0625.353620.994720.087319.794719.5552431.20541.20611.20491.20511.20570
424.453420.699219.042418.324217.9826-3.3247e-0523.771420.433818.863018.190317.875913121.19391.19431.19371.19381.19410
\n", + "
" + ], + "text/plain": [ + " u g r i z redshift dered_u dered_g \\\n", + "id \n", + "0 23.2640 20.3368 19.0095 17.6724 16.9396 -8.1086e-05 23.1243 20.2578 \n", + "1 15.0521 14.0620 13.4524 13.2684 13.1689 4.5061e-03 14.9664 14.0045 \n", + "2 16.7864 15.8254 15.5363 15.3935 15.3500 4.7198e-04 16.6076 15.6866 \n", + "3 25.6606 21.1887 20.2212 19.8949 19.6346 5.8143e-06 25.3536 20.9947 \n", + "4 24.4534 20.6992 19.0424 18.3242 17.9826 -3.3247e-05 23.7714 20.4338 \n", + "\n", + " dered_r dered_i dered_z nObserve nDetect airmass_u airmass_g \\\n", + "id \n", + "0 18.9551 17.6321 16.9089 18 18 1.1898 1.1907 \n", + "1 13.4114 13.2363 13.1347 1 1 1.2533 1.2578 \n", + "2 15.4400 15.3217 15.2961 2 2 1.0225 1.0241 \n", + "3 20.0873 19.7947 19.5552 4 3 1.2054 1.2061 \n", + "4 18.8630 18.1903 17.8759 13 12 1.1939 1.1943 \n", + "\n", + " airmass_r airmass_i airmass_z class \n", + "id \n", + "0 1.1890 1.1894 1.1902 0 \n", + "1 1.2488 1.2510 1.2555 1 \n", + "2 1.0210 1.0217 1.0233 0 \n", + "3 1.2049 1.2051 1.2057 0 \n", + "4 1.1937 1.1938 1.1941 0 " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trn = pd.read_csv(trn_file, index_col=0)\n", + "print(trn.shape)\n", + "trn.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:11.195777Z", + "start_time": "2021-09-07T07:08:11.141764Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(200004, 1) (200004,)\n" + ] + } + ], + "source": [ + "X = trn[trn['class'] != 0][['u']].values\n", + "y = trn[trn['class'] != 0]['dered_u'].values\n", + "print(X.shape, y.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NumPy를 이용한 역전파 학습" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:11.216644Z", + "start_time": "2021-09-07T07:08:11.197245Z" + } + }, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "w1: [[-0.0013 0.0045]]\n", + "b1: [[0.0023 0.001 ]]\n", + "w2: [[-0.0034]\n", + " [-0.0034]]\n", + "b2: [[-0.0044]]\n" + ] + } + ], + "source": [ + "alpha = .001\n", + "\n", + "n_input = 1\n", + "n_hidden = 2\n", + "n_output = 1\n", + "\n", + "w1 = (np.random.rand(n_input, n_hidden) - .5) * .01\n", + "b1 = (np.random.rand(1, n_hidden) - .5) * .01\n", + "\n", + "w2 = (np.random.rand(n_hidden, n_output) - .5) * .01\n", + "b2 = (np.random.rand(1, n_output) - .5) * .01\n", + "\n", + "epoch = 5\n", + "print(f'w1: {w1}\\nb1: {b1}\\nw2: {w2}\\nb2: {b2}')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2021-09-07T07:08:11.305609Z", + "start_time": "2021-09-07T07:08:11.219429Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef97998bab9f4209880c66a1f5c9e29a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/5 [00:00" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(8, 8))\n", + "plt.scatter(X, y, alpha=.2)\n", + "plt.scatter(X, np.maximum(X @ w1 + b1, 0.) @ w2 + b2, alpha=.2, color='darkorange')\n", + "plt.xlim(0, 40)\n", + "plt.ylim(0, 40)\n", + "plt.xlabel('x')\n", + "plt.ylabel('y')\n", + "plt.legend(['data', 'prediction'])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}