From 343c9d9de8bc79d199b26348773e897d41382ae9 Mon Sep 17 00:00:00 2001 From: Jon Perl Date: Sat, 1 Apr 2017 16:49:27 -0400 Subject: [PATCH] Show the embedding in tensorboard --- .gitignore | 1 + stock2vec.ipynb | 227 ++++++++++++++++++++++++------------------------ 2 files changed, 116 insertions(+), 112 deletions(-) diff --git a/.gitignore b/.gitignore index 2291738..14666d3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .ipynb_checkpoints/ input/ +output/ .floyd* diff --git a/stock2vec.ipynb b/stock2vec.ipynb index d0cbaf2..68c4003 100644 --- a/stock2vec.ipynb +++ b/stock2vec.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "collapsed": false }, @@ -23,9 +23,11 @@ "import multiprocessing as mp\n", "import numpy as np\n", "import pandas as pd\n", + "import os\n", "import random\n", "import sys\n", "import tensorflow as tf\n", + "from tensorflow.contrib.tensorboard.plugins import projector\n", "import time\n", "from functools import partial\n", "from tqdm import tqdm" @@ -40,7 +42,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "CSV_URL = 'https://s3.amazonaws.com/perl-ml/prices.csv?response-content-disposition=attachment&X-Amz-Security-Token=FQoDYXdzECIaDLG1ZU6Yzztd7CsNGCKsAgNa3zgOVIw%2BQB8y%2FcRAMdAYK0ZPWW59OqVSuRuFGv3NEX3LapeZnns4VZleRraw1352r%2BP1CJm2hqgg2OlGcjf8pa414x90CDCdyIemO8HJwoIr4nKi18945ZmxthTL04BJsHD1MN0Tp%2F30A3kUMqscJP68vuQ75w098gKBJFxlnKztFUnP91Myn3%2FrrNUKQ%2F%2BODJx%2Bmpu7CMOGZlDLlSHtpTKbo8pULbHFGZAe%2BAvPqq0KU71nJ%2FWjUPcbLaEjSxOZl3%2BP98cePjijlMC8O6r9JzjTqGKUUUiqOWA92QZ6UtZfUlkyO%2BcNdLGltRJrCkGEctmyhJ6Qnim0eIfSBlzhDVPAtuAdTDrXzi2d3SGOJNm8P56ak71Vnk7P%2FSyGZsdQ9G0nMXBH1GeG5yjr7ebGBQ%3D%3D&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20170328T010700Z&X-Amz-SignedHeaders=host&X-Amz-Expires=300&X-Amz-Credential=ASIAJBTQPDQAOL557TLA%2F20170328%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=16a879624653ee25590a42768d975982001f3451249973af25e9d93942fec054'\n", + "FILE_NAME = 'input/prices.csv'\n", + "LOG_DIR = 'output'\n", + "MODEL_PATH = os.path.join(LOG_DIR, \"model.ckpt\")\n", + "STOCK_PATH = os.path.join(LOG_DIR,'stock.tsv')\n", + "\n", + "if not os.path.exists(LOG_DIR):\n", + " os.makedirs(LOG_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": { "collapsed": false }, @@ -49,9 +69,8 @@ "from urllib.request import urlretrieve\n", "from os.path import isfile, isdir\n", "\n", - "dataset_folder_path = 'input'\n", - "dataset_filename = 'input/prices.csv'\n", - "dataset_name = 'Prices'\n", + "if not os.path.exists('input'):\n", + " os.makedirs('input')\n", "\n", "class DLProgress(tqdm):\n", " last_block = 0\n", @@ -61,41 +80,18 @@ " self.update((block_num - self.last_block) * block_size)\n", " self.last_block = block_num\n", "\n", - "if not isfile(dataset_filename):\n", - " with DLProgress(unit='B', unit_scale=True, miniters=1, desc=dataset_name) as pbar:\n", - " urlretrieve(\n", - " 'https://s3.amazonaws.com/perl-ml/prices.csv?response-content-disposition=attachment&X-Amz-Security-Token=FQoDYXdzEN3%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEaDDGDXIQxfhjhlalnoyKqAiqteedReEObibGFinGZUTbCNLqOsBrBfhb3m%2B9WSc202KdlXdoi8bxYATvctErFAeNF%2FlVgdPlu%2BRy8dLOHw5a%2BvTNM92V8V1XiJnuYgpe69GI914L1xceQGmcJ9qQ1Fg2iSi5cGj2%2FNL26CHIOmdblBGp6VUFUqtu0ZoRb18XXYBlSGQIGk4kxGfwiN5%2BbnQNB%2FInBx0YkDI5XFOIOXa1HzF4anoHgoSSjwdq8FXLQh8LXD5mYvqkTLokIssfZeJrc4TyPy9gZW4hewwbI4NAauQvJfde2Z%2BA%2B5iV4%2B%2B8wFFcDMeM%2Fg%2BYyrTVhaRVZ%2FIU033J6CXshjaL0uHwFleXw%2FHlzMjQst2YZmQu0EqxNCowwwxcugVsKcMaPdMq%2BWJ66qWxN5DZcC3oo%2FJvbxgU%3D&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20170325T201144Z&X-Amz-SignedHeaders=host&X-Amz-Expires=300&X-Amz-Credential=ASIAIZGCD6XQ355X2AMA%2F20170325%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=0c2703d3dbef5f58006a3b7e89ff85b2b86e67542b861784fbff2da48434e0df',\n", - " dataset_filename,\n", - " pbar.hook)" + "if not isfile(FILE_NAME):\n", + " with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Prices') as pbar:\n", + " urlretrieve(CSV_URL, FILE_NAME, pbar.hook)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "rows: 10000000it [00:14, 668326.67it/s] \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " adj_close date ticker epsdil pe\n", - "45526 13.745073 2001-06-18 OLED -0.87 -15.798935\n", - "3046709 6.147269 2001-06-18 YUM 0.69 8.909085\n", - "1778402 43.419875 2001-06-18 SWY 2.31 18.796483\n", - "315864 29.645033 2001-06-18 PEP 1.50 19.763356\n", - "1743245 11.213455 2001-06-18 SVU 0.47 23.858415\n" - ] - } - ], + "outputs": [], "source": [ "chunksize = 1000000\n", "price_rows = 9191528\n", @@ -137,25 +133,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "window for 0 OLED 0\n", - "1 YUM\n", - "window for 9 SVU 2\n", - "3 SWY\n", - "4 PEP\n", - "window for 18 PEP 4\n", - "3 SWY\n" - ] - } - ], + "outputs": [], "source": [ "ticker_to_int = {}\n", "int_to_ticker = {}\n", @@ -194,21 +176,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 9191/9191 [01:23<00:00, 106.73it/s]" - ] - } - ], + "outputs": [], "source": [ - "batch_size = 1000\n", + "batch_size = 10000\n", "window_size = 10\n", "\n", "total_prices = len(prices)\n", @@ -241,6 +215,19 @@ "batches = get_batches()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Save embedding metadata\n", + "with open(STOCK_PATH, 'w') as out:\n", + " out.write('\\n'.join(int_to_ticker.values()))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -250,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "collapsed": false }, @@ -266,7 +253,7 @@ "n_embedding = 400 # Number of embedding features \n", "\n", "with train_graph.as_default():\n", - " embedding = tf.Variable(tf.random_uniform((n_stocks, n_embedding), -1, 1))\n", + " embedding = tf.Variable(tf.random_uniform((n_stocks, n_embedding), -1, 1), name='stock_embedding')\n", " embed = tf.nn.embedding_lookup(embedding, inputs)" ] }, @@ -279,21 +266,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "collapsed": false }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - " \r", - "9192it [01:40, 106.73it/s]" - ] - } - ], + "outputs": [], "source": [ "# Number of negative labels to sample\n", "n_sampled = 100\n", @@ -313,7 +290,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "collapsed": false }, @@ -351,41 +328,7 @@ "collapsed": false, "scrolled": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/10 Iteration: 100 Avg. Training loss: 4.7353 0.1324 sec/batch\n", - "Epoch 1/10 Iteration: 200 Avg. Training loss: 4.3297 0.1308 sec/batch\n", - "Epoch 1/10 Iteration: 300 Avg. Training loss: 3.9655 0.1325 sec/batch\n", - "Epoch 1/10 Iteration: 400 Avg. Training loss: 3.5779 0.1372 sec/batch\n", - "Epoch 1/10 Iteration: 500 Avg. Training loss: 3.4165 0.1408 sec/batch\n", - "Epoch 1/10 Iteration: 600 Avg. Training loss: 3.3865 0.1415 sec/batch\n", - "Epoch 1/10 Iteration: 700 Avg. Training loss: 3.4901 0.1380 sec/batch\n", - "Epoch 1/10 Iteration: 800 Avg. Training loss: 3.3706 0.1374 sec/batch\n", - "Epoch 1/10 Iteration: 900 Avg. Training loss: 3.1838 0.1338 sec/batch\n", - "Epoch 1/10 Iteration: 1000 Avg. Training loss: 3.2497 0.1291 sec/batch\n", - "Nearest to ACAS: LCI, MCP, AI, AIT, FLR, CTWS, BWLD, GNE,\n", - "Nearest to CLFD: A, GRT, USM, ITMN, NMRX, ESIO, CAS, LMIA,\n", - "Nearest to CSC: FITB, RFP, JAH, EQT, AME, ACAD, AAMC, CBT,\n", - "Nearest to RVLT: AMAG, PES, CPWR, HP, PLUG, DTLK, RNWK, DXLG,\n", - "Nearest to FLS: SYK, QGEN, DIS, CUR, INTC, TIF, NEM, ALEX,\n", - "Nearest to GGG: TGT, STZ, UHS, FBNK, ACHC, FINL, AFAM, MAIN,\n", - "Nearest to PRXL: PNRA, GTS, EA, SYX, BBY, BSET, HTA, JBLU,\n", - "Nearest to SPTN: FC, EGHT, TWI, SUP, CLH, CMS, GLW, CUBI,\n", - "Nearest to OFG: NL, AFOP, AXS, WGO, EPR, GLAD, SCG, CEVA,\n", - "Nearest to ASTE: CRK, SEAC, BDE, STAA, NPK, PPO, HXL, DE,\n", - "Nearest to BGC: EDIG, TIVO, PRKR, NYLD, SWC, OSUR, PLPC, KTOS,\n", - "Nearest to NANO: SGEN, RELL, EDIG, VGR, HLIT, ENR, CAMP, NDAQ,\n", - "Nearest to MTD: XLS, FLTX, MLI, DIS, MGNX, SCCO, BGFV, POWL,\n", - "Nearest to LLTC: HDNG, HLX, VAR, WOOF, LNN, CLVS, CAKE, EFII,\n", - "Nearest to TPC: CFNB, CLI, PEG, NWSA, FCX, HOT, FBC, SANM,\n", - "Nearest to ORLY: ACC, RSTI, DHR, FLR, WY, ORIT, MTD, CWST,\n", - "Epoch 1/10 Iteration: 1100 Avg. Training loss: 3.4303 0.1285 sec/batch\n" - ] - } - ], + "outputs": [], "source": [ "epochs = 10\n", "\n", @@ -418,7 +361,7 @@ " loss = 0\n", " start = time.time()\n", " \n", - " if iteration % 1000 == 0:\n", + " if iteration % 10000 == 0:\n", " # note that this is expensive (~20% slowdown if computed every 500 steps)\n", " sim = similarity.eval()\n", " for i in range(valid_size):\n", @@ -435,10 +378,70 @@ " print(log)\n", " \n", " iteration += 1\n", - " save_path = saver.save(sess, \"checkpoints/text8.ckpt\")\n", + " save_path = saver.save(sess, MODEL_PATH)\n", " embed_mat = sess.run(normalized_embedding)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Save the embedding for tensorboard\n", + "\n", + "with train_graph.as_default():\n", + " saver = tf.train.Saver()\n", + "\n", + "with tf.Session(graph=train_graph) as sess:\n", + " saver.restore(sess, MODEL_PATH)\n", + " \n", + " config = projector.ProjectorConfig()\n", + "\n", + " viz_embedding = config.embeddings.add()\n", + " viz_embedding.tensor_name = embedding.name\n", + " viz_embedding.metadata_path = STOCK_PATH\n", + " summary_writer = tf.summary.FileWriter(LOG_DIR)\n", + " projector.visualize_embeddings(summary_writer, config)\n", + "\n", + " saver.save(sess, MODEL_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%config InlineBackend.figure_format = 'retina'\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.manifold import TSNE\n", + "\n", + "viz_stocks = 1000\n", + "tsne = TSNE()\n", + "embed_tsne = tsne.fit_transform(embed_mat[:viz_stocks, :])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(20, 20))\n", + "for idx in range(viz_stocks):\n", + " plt.scatter(*embed_tsne[idx, :], color='steelblue')\n", + " plt.annotate(int_to_ticker[idx], (embed_tsne[idx, 0], embed_tsne[idx, 1]), alpha=0.7)" + ] + }, { "cell_type": "code", "execution_count": null,