Skip to content

Commit

Permalink
Show the embedding in tensorboard
Browse files Browse the repository at this point in the history
  • Loading branch information
jperl committed Apr 1, 2017
1 parent c166bbd commit 343c9d9
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 112 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.ipynb_checkpoints/
input/
output/
.floyd*
227 changes: 115 additions & 112 deletions stock2vec.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"collapsed": false
},
Expand All @@ -23,9 +23,11 @@
"import multiprocessing as mp\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import random\n",
"import sys\n",
"import tensorflow as tf\n",
"from tensorflow.contrib.tensorboard.plugins import projector\n",
"import time\n",
"from functools import partial\n",
"from tqdm import tqdm"
Expand All @@ -40,7 +42,25 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
# --- Data source & output path configuration --------------------------------
# NOTE(review): CSV_URL is a pre-signed S3 URL; the embedded X-Amz-* security
# token and signature expire 300 s after minting (X-Amz-Expires=300), so the
# hard-coded URL is both a credential leak and long dead.  Allow an
# environment-variable override so a freshly minted URL can be supplied
# without editing the notebook; the original URL is kept as the default for
# backward compatibility.
CSV_URL = os.environ.get(
    'PRICES_CSV_URL',
    'https://s3.amazonaws.com/perl-ml/prices.csv?response-content-disposition=attachment&X-Amz-Security-Token=FQoDYXdzECIaDLG1ZU6Yzztd7CsNGCKsAgNa3zgOVIw%2BQB8y%2FcRAMdAYK0ZPWW59OqVSuRuFGv3NEX3LapeZnns4VZleRraw1352r%2BP1CJm2hqgg2OlGcjf8pa414x90CDCdyIemO8HJwoIr4nKi18945ZmxthTL04BJsHD1MN0Tp%2F30A3kUMqscJP68vuQ75w098gKBJFxlnKztFUnP91Myn3%2FrrNUKQ%2F%2BODJx%2Bmpu7CMOGZlDLlSHtpTKbo8pULbHFGZAe%2BAvPqq0KU71nJ%2FWjUPcbLaEjSxOZl3%2BP98cePjijlMC8O6r9JzjTqGKUUUiqOWA92QZ6UtZfUlkyO%2BcNdLGltRJrCkGEctmyhJ6Qnim0eIfSBlzhDVPAtuAdTDrXzi2d3SGOJNm8P56ak71Vnk7P%2FSyGZsdQ9G0nMXBH1GeG5yjr7ebGBQ%3D%3D&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20170328T010700Z&X-Amz-SignedHeaders=host&X-Amz-Expires=300&X-Amz-Credential=ASIAJBTQPDQAOL557TLA%2F20170328%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=16a879624653ee25590a42768d975982001f3451249973af25e9d93942fec054')

FILE_NAME = 'input/prices.csv'                      # local cache of the raw price CSV
LOG_DIR = 'output'                                  # checkpoint + TensorBoard log directory
MODEL_PATH = os.path.join(LOG_DIR, "model.ckpt")    # model checkpoint path
STOCK_PATH = os.path.join(LOG_DIR, 'stock.tsv')     # projector metadata (one ticker per row)

# exist_ok avoids the check-then-create race of the original
# `if not os.path.exists(...)` guard (available since Python 3.2).
os.makedirs(LOG_DIR, exist_ok=True)
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
Expand All @@ -49,9 +69,8 @@
"from urllib.request import urlretrieve\n",
"from os.path import isfile, isdir\n",
"\n",
"dataset_folder_path = 'input'\n",
"dataset_filename = 'input/prices.csv'\n",
"dataset_name = 'Prices'\n",
"if not os.path.exists('input'):\n",
" os.makedirs('input')\n",
"\n",
"class DLProgress(tqdm):\n",
" last_block = 0\n",
Expand All @@ -61,41 +80,18 @@
" self.update((block_num - self.last_block) * block_size)\n",
" self.last_block = block_num\n",
"\n",
"if not isfile(dataset_filename):\n",
" with DLProgress(unit='B', unit_scale=True, miniters=1, desc=dataset_name) as pbar:\n",
" urlretrieve(\n",
" 'https://s3.amazonaws.com/perl-ml/prices.csv?response-content-disposition=attachment&X-Amz-Security-Token=FQoDYXdzEN3%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEaDDGDXIQxfhjhlalnoyKqAiqteedReEObibGFinGZUTbCNLqOsBrBfhb3m%2B9WSc202KdlXdoi8bxYATvctErFAeNF%2FlVgdPlu%2BRy8dLOHw5a%2BvTNM92V8V1XiJnuYgpe69GI914L1xceQGmcJ9qQ1Fg2iSi5cGj2%2FNL26CHIOmdblBGp6VUFUqtu0ZoRb18XXYBlSGQIGk4kxGfwiN5%2BbnQNB%2FInBx0YkDI5XFOIOXa1HzF4anoHgoSSjwdq8FXLQh8LXD5mYvqkTLokIssfZeJrc4TyPy9gZW4hewwbI4NAauQvJfde2Z%2BA%2B5iV4%2B%2B8wFFcDMeM%2Fg%2BYyrTVhaRVZ%2FIU033J6CXshjaL0uHwFleXw%2FHlzMjQst2YZmQu0EqxNCowwwxcugVsKcMaPdMq%2BWJ66qWxN5DZcC3oo%2FJvbxgU%3D&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=20170325T201144Z&X-Amz-SignedHeaders=host&X-Amz-Expires=300&X-Amz-Credential=ASIAIZGCD6XQ355X2AMA%2F20170325%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=0c2703d3dbef5f58006a3b7e89ff85b2b86e67542b861784fbff2da48434e0df',\n",
" dataset_filename,\n",
" pbar.hook)"
"if not isfile(FILE_NAME):\n",
" with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Prices') as pbar:\n",
" urlretrieve(CSV_URL, FILE_NAME, pbar.hook)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"rows: 10000000it [00:14, 668326.67it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" adj_close date ticker epsdil pe\n",
"45526 13.745073 2001-06-18 OLED -0.87 -15.798935\n",
"3046709 6.147269 2001-06-18 YUM 0.69 8.909085\n",
"1778402 43.419875 2001-06-18 SWY 2.31 18.796483\n",
"315864 29.645033 2001-06-18 PEP 1.50 19.763356\n",
"1743245 11.213455 2001-06-18 SVU 0.47 23.858415\n"
]
}
],
"outputs": [],
"source": [
"chunksize = 1000000\n",
"price_rows = 9191528\n",
Expand Down Expand Up @@ -137,25 +133,11 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"window for 0 OLED 0\n",
"1 YUM\n",
"window for 9 SVU 2\n",
"3 SWY\n",
"4 PEP\n",
"window for 18 PEP 4\n",
"3 SWY\n"
]
}
],
"outputs": [],
"source": [
"ticker_to_int = {}\n",
"int_to_ticker = {}\n",
Expand Down Expand Up @@ -194,21 +176,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 9191/9191 [01:23<00:00, 106.73it/s]"
]
}
],
"outputs": [],
"source": [
"batch_size = 1000\n",
"batch_size = 10000\n",
"window_size = 10\n",
"\n",
"total_prices = len(prices)\n",
Expand Down Expand Up @@ -241,6 +215,19 @@
"batches = get_batches()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
# Write the TensorBoard projector metadata file: one ticker per line, where
# line i must label row i of the embedding matrix.  Iterating in index order
# guarantees that alignment; the original joined int_to_ticker.values(),
# whose dict iteration order is not index order (and, before Python 3.7,
# not even guaranteed insertion order), so labels could be shuffled.
# Assumes int_to_ticker keys are the dense ids 0..n-1 — TODO confirm against
# the vocabulary-building cell above.
with open(STOCK_PATH, 'w') as out:
    out.write('\n'.join(int_to_ticker[i] for i in range(len(int_to_ticker))))
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -250,7 +237,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"collapsed": false
},
Expand All @@ -266,7 +253,7 @@
"n_embedding = 400 # Number of embedding features \n",
"\n",
"with train_graph.as_default():\n",
" embedding = tf.Variable(tf.random_uniform((n_stocks, n_embedding), -1, 1))\n",
" embedding = tf.Variable(tf.random_uniform((n_stocks, n_embedding), -1, 1), name='stock_embedding')\n",
" embed = tf.nn.embedding_lookup(embedding, inputs)"
]
},
Expand All @@ -279,21 +266,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" \r",
"9192it [01:40, 106.73it/s]"
]
}
],
"outputs": [],
"source": [
"# Number of negative labels to sample\n",
"n_sampled = 100\n",
Expand All @@ -313,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"collapsed": false
},
Expand Down Expand Up @@ -351,41 +328,7 @@
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10 Iteration: 100 Avg. Training loss: 4.7353 0.1324 sec/batch\n",
"Epoch 1/10 Iteration: 200 Avg. Training loss: 4.3297 0.1308 sec/batch\n",
"Epoch 1/10 Iteration: 300 Avg. Training loss: 3.9655 0.1325 sec/batch\n",
"Epoch 1/10 Iteration: 400 Avg. Training loss: 3.5779 0.1372 sec/batch\n",
"Epoch 1/10 Iteration: 500 Avg. Training loss: 3.4165 0.1408 sec/batch\n",
"Epoch 1/10 Iteration: 600 Avg. Training loss: 3.3865 0.1415 sec/batch\n",
"Epoch 1/10 Iteration: 700 Avg. Training loss: 3.4901 0.1380 sec/batch\n",
"Epoch 1/10 Iteration: 800 Avg. Training loss: 3.3706 0.1374 sec/batch\n",
"Epoch 1/10 Iteration: 900 Avg. Training loss: 3.1838 0.1338 sec/batch\n",
"Epoch 1/10 Iteration: 1000 Avg. Training loss: 3.2497 0.1291 sec/batch\n",
"Nearest to ACAS: LCI, MCP, AI, AIT, FLR, CTWS, BWLD, GNE,\n",
"Nearest to CLFD: A, GRT, USM, ITMN, NMRX, ESIO, CAS, LMIA,\n",
"Nearest to CSC: FITB, RFP, JAH, EQT, AME, ACAD, AAMC, CBT,\n",
"Nearest to RVLT: AMAG, PES, CPWR, HP, PLUG, DTLK, RNWK, DXLG,\n",
"Nearest to FLS: SYK, QGEN, DIS, CUR, INTC, TIF, NEM, ALEX,\n",
"Nearest to GGG: TGT, STZ, UHS, FBNK, ACHC, FINL, AFAM, MAIN,\n",
"Nearest to PRXL: PNRA, GTS, EA, SYX, BBY, BSET, HTA, JBLU,\n",
"Nearest to SPTN: FC, EGHT, TWI, SUP, CLH, CMS, GLW, CUBI,\n",
"Nearest to OFG: NL, AFOP, AXS, WGO, EPR, GLAD, SCG, CEVA,\n",
"Nearest to ASTE: CRK, SEAC, BDE, STAA, NPK, PPO, HXL, DE,\n",
"Nearest to BGC: EDIG, TIVO, PRKR, NYLD, SWC, OSUR, PLPC, KTOS,\n",
"Nearest to NANO: SGEN, RELL, EDIG, VGR, HLIT, ENR, CAMP, NDAQ,\n",
"Nearest to MTD: XLS, FLTX, MLI, DIS, MGNX, SCCO, BGFV, POWL,\n",
"Nearest to LLTC: HDNG, HLX, VAR, WOOF, LNN, CLVS, CAKE, EFII,\n",
"Nearest to TPC: CFNB, CLI, PEG, NWSA, FCX, HOT, FBC, SANM,\n",
"Nearest to ORLY: ACC, RSTI, DHR, FLR, WY, ORIT, MTD, CWST,\n",
"Epoch 1/10 Iteration: 1100 Avg. Training loss: 3.4303 0.1285 sec/batch\n"
]
}
],
"outputs": [],
"source": [
"epochs = 10\n",
"\n",
Expand Down Expand Up @@ -418,7 +361,7 @@
" loss = 0\n",
" start = time.time()\n",
" \n",
" if iteration % 1000 == 0:\n",
" if iteration % 10000 == 0:\n",
" # note that this is expensive (~20% slowdown if computed every 500 steps)\n",
" sim = similarity.eval()\n",
" for i in range(valid_size):\n",
Expand All @@ -435,10 +378,70 @@
" print(log)\n",
" \n",
" iteration += 1\n",
" save_path = saver.save(sess, \"checkpoints/text8.ckpt\")\n",
" save_path = saver.save(sess, MODEL_PATH)\n",
" embed_mat = sess.run(normalized_embedding)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
# Export the trained embedding for TensorBoard's projector: restore the
# trained weights, attach the ticker metadata file to the embedding tensor,
# emit the projector config into LOG_DIR, and re-save the checkpoint there.

with train_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=train_graph) as sess:
    saver.restore(sess, MODEL_PATH)

    proj_config = projector.ProjectorConfig()

    emb_view = proj_config.embeddings.add()
    emb_view.tensor_name = embedding.name
    # NOTE(review): some TensorBoard versions resolve metadata_path relative
    # to the log dir rather than the working directory — confirm that
    # 'output/stock.tsv' actually loads in the projector UI.
    emb_view.metadata_path = STOCK_PATH

    writer = tf.summary.FileWriter(LOG_DIR)
    projector.visualize_embeddings(writer, proj_config)

    saver.save(sess, MODEL_PATH)
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina'\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.manifold import TSNE\n",
"\n",
"viz_stocks = 1000\n",
"tsne = TSNE()\n",
"embed_tsne = tsne.fit_transform(embed_mat[:viz_stocks, :])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
# Scatter the 2-D t-SNE projection and label each point with its ticker.
fig, ax = plt.subplots(figsize=(20, 20))
# Use the explicit Axes interface (the original created `ax` but then drew
# through the pyplot state machine), and draw one vectorized scatter call
# instead of `viz_stocks` separate calls.
ax.scatter(embed_tsne[:viz_stocks, 0], embed_tsne[:viz_stocks, 1], color='steelblue')
for idx in range(viz_stocks):
    ax.annotate(int_to_ticker[idx], (embed_tsne[idx, 0], embed_tsne[idx, 1]), alpha=0.7)
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 343c9d9

Please sign in to comment.