forked from lilianweng/stock-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
121 lines (94 loc) · 4.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import pandas as pd
import pprint
import tensorflow as tf
import tensorflow.contrib.slim as slim
from data_model import StockDataSet
from model_rnn import LstmRNN
# Command-line flags, parsed by tf.app.run() before main() is invoked.
# Each row is (define function, flag name, default value, help text); rows
# are registered in exactly the order listed.
flags = tf.app.flags
for _define, _name, _default, _help in (
    (flags.DEFINE_integer, "stock_count", 100, "Stock count [100]"),
    (flags.DEFINE_integer, "input_size", 5, "Input size [5]"),
    (flags.DEFINE_integer, "num_steps", 30, "Num of steps [30]"),
    (flags.DEFINE_integer, "num_layers", 1, "Num of layer [1]"),
    (flags.DEFINE_integer, "lstm_size", 128, "Size of one LSTM cell [128]"),
    (flags.DEFINE_integer, "batch_size", 64, "The size of batch images [64]"),
    (flags.DEFINE_float, "keep_prob", 0.8, "Keep probability of dropout layer. [0.8]"),
    (flags.DEFINE_float, "init_learning_rate", 0.001, "Initial learning rate at early stage. [0.001]"),
    (flags.DEFINE_float, "learning_rate_decay", 0.99, "Decay rate of learning rate. [0.99]"),
    (flags.DEFINE_integer, "init_epoch", 5, "Num. of epoches considered as early stage. [5]"),
    (flags.DEFINE_integer, "max_epoch", 50, "Total training epoches. [50]"),
    (flags.DEFINE_integer, "embed_size", None, "If provided, use embedding vector of this size. [None]"),
    (flags.DEFINE_string, "stock_symbol", None, "Target stock symbol [None]"),
    (flags.DEFINE_string, "checkpoint_dir", "checkpoints", "Directory name to save the checkpoints [checkpoints]"),
    (flags.DEFINE_integer, "sample_size", 4, "Number of stocks to plot during training. [4]"),
    (flags.DEFINE_string, "plot_dir", "images", "Directory name to save plots [images]"),
    (flags.DEFINE_boolean, "train", False, "True for training, False for testing [False]"),
):
    _define(_name, _default, _help)
# Drop the loop temporaries so they do not linger in the module namespace.
del _define, _name, _default, _help

FLAGS = flags.FLAGS

# Shared pretty-printer; main() uses it to dump the parsed flag values.
pp = pprint.PrettyPrinter()

# The "logs" directory is created at import time; load_sp500 writes
# logs/metadata.tsv into it.
if not os.path.exists("logs"):
    os.mkdir("logs")
def show_all_variables():
    """Print a summary of every trainable variable in the default graph.

    Delegates to ``slim.model_analyzer.analyze_vars`` with
    ``print_info=True``; whatever that call returns is discarded, so this
    function returns None.
    """
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
def load_sp500(input_size, num_steps, k=None, target_symbol=None, test_ratio=0.05):
    """Build StockDataSet objects for one symbol or for the S&P 500 stocks.

    Args:
        input_size: forwarded to every ``StockDataSet``.
        num_steps: forwarded to every ``StockDataSet``.
        k: if not None, keep only the first ``k`` stocks after sorting by
            ``market_cap`` in descending order. None keeps every stock that
            has a data file.
        target_symbol: if given, the S&P 500 metadata is not read at all and
            a single-element list for this symbol is returned.
        test_ratio: forwarded to every ``StockDataSet`` on both code paths.

    Returns:
        list of ``StockDataSet``.

    Side effects (multi-stock path only):
        Reads ``data/constituents-financials.csv``, checks for
        ``data/<symbol>.csv``, prints summaries, and writes the
        ``symbol``/``sector`` columns to ``logs/metadata.tsv``
        (tab-separated). The ``logs`` directory must already exist.
    """
    if target_symbol is not None:
        return [
            StockDataSet(
                target_symbol,
                input_size=input_size,
                num_steps=num_steps,
                test_ratio=test_ratio)
        ]

    # Load metadata of S&P 500 stocks; normalize headers such as "Market Cap"
    # to snake_case ("market_cap").
    info = pd.read_csv("data/constituents-financials.csv")
    info = info.rename(columns={col: col.lower().replace(' ', '_') for col in info.columns})

    # Only keep symbols for which a price file is present on disk.
    info['file_exists'] = info['symbol'].map(
        lambda sym: os.path.exists(os.path.join("data", "{}.csv".format(sym))))
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print(info['file_exists'].value_counts().to_dict())
    # Explicit comparison keeps a boolean mask even when the frame is empty.
    info = info[info['file_exists'] == True].reset_index(drop=True)  # noqa: E712

    # DataFrame.sort was removed in pandas 0.20; sort_values is its replacement.
    info = info.sort_values('market_cap', ascending=False).reset_index(drop=True)
    if k is not None:
        info = info.head(k)
    print("Head of S&P 500 info:\n" + str(info.head()))

    # Generate embedding meta file.
    info[['symbol', 'sector']].to_csv(
        os.path.join("logs", "metadata.tsv"), sep='\t', index=False)

    # Forward test_ratio (previously hard-coded to 0.05 on this path).
    return [
        StockDataSet(row['symbol'],
                     input_size=input_size,
                     num_steps=num_steps,
                     test_ratio=test_ratio)
        for _, row in info.iterrows()]
def main(_):
    """Entry point called by ``tf.app.run``: build the LSTM model, then train or load it.

    The single positional argument (leftover argv supplied by ``tf.app.run``)
    is ignored. The checkpoint and plot directories are created when missing.
    With ``--train`` the model is trained on the loaded stock data; otherwise
    ``rnn_model.load()[0]`` must be truthy, or an ``Exception`` is raised.
    """
    pp.pprint(FLAGS.__flags)

    for required_dir in (FLAGS.checkpoint_dir, FLAGS.plot_dir):
        if not os.path.exists(required_dir):
            os.makedirs(required_dir)

    session_config = tf.ConfigProto()
    # Grow GPU memory usage on demand rather than reserving it all up front.
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        rnn_model = LstmRNN(
            sess,
            FLAGS.stock_count,
            lstm_size=FLAGS.lstm_size,
            num_layers=FLAGS.num_layers,
            num_steps=FLAGS.num_steps,
            input_size=FLAGS.input_size,
            keep_prob=FLAGS.keep_prob,
            embed_size=FLAGS.embed_size,
            checkpoint_dir=FLAGS.checkpoint_dir,
        )
        show_all_variables()

        stock_data_list = load_sp500(
            FLAGS.input_size,
            FLAGS.num_steps,
            k=FLAGS.stock_count,
            target_symbol=FLAGS.stock_symbol,
        )

        if FLAGS.train:
            rnn_model.train(stock_data_list, FLAGS)
        elif not rnn_model.load()[0]:
            raise Exception("[!] Train a model first, then run test mode")
if __name__ == '__main__':
    # Parse the command-line flags, then hand control to main() defined above.
    tf.app.run(main=main)