-
Notifications
You must be signed in to change notification settings - Fork 364
/
Copy pathmain.py
158 lines (139 loc) · 6.75 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import sys
import logging
import argparse
import json
from quantylab.rltrader import settings
from quantylab.rltrader import utils
from quantylab.rltrader import data_manager
# Minimum / maximum amount for a single trade, applied to every stock code.
MIN_TRADING_PRICE = 100000
MAX_TRADING_PRICE = 10000000


def _parse_args(argv=None):
    """Parse the RLTrader command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the run options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'test', 'update', 'predict'], default='train')
    parser.add_argument('--ver', choices=['v1', 'v2', 'v3', 'v4'], default='v2')
    parser.add_argument('--name', default=utils.get_time_str())
    # Both options are needed to build a learner. Without them the script
    # used to fail only after creating output dirs (TypeError / failed assert).
    parser.add_argument('--stock_code', nargs='+', required=True)
    parser.add_argument('--rl_method', choices=['dqn', 'pg', 'ac', 'a2c', 'a3c', 'monkey'],
                        required=True)
    parser.add_argument('--net', choices=['dnn', 'lstm', 'cnn', 'monkey'], default='dnn')
    parser.add_argument('--backend', choices=['pytorch', 'tensorflow', 'plaidml'], default='pytorch')
    parser.add_argument('--start_date', default='20200101')
    parser.add_argument('--end_date', default='20201231')
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--discount_factor', type=float, default=0.7)
    parser.add_argument('--balance', type=int, default=100000000)
    return parser.parse_args(argv)


def _derive_run_config(mode, net):
    """Derive learner settings that depend only on the run mode and network.

    Args:
        mode: One of 'train', 'test', 'update', 'predict'.
        net: Network name given via ``--net``.

    Returns:
        dict with keys 'learning', 'reuse_models', 'start_epsilon',
        'num_epoches' and 'num_steps'.
    """
    learning = mode in ('train', 'update')
    return {
        'learning': learning,
        # Every mode except a fresh 'train' starts from previously saved models.
        'reuse_models': mode in ('test', 'update', 'predict'),
        'start_epsilon': 1 if learning else 0,
        'num_epoches': 1000 if learning else 1,
        # lstm and cnn networks get num_steps=5; all other nets use 1.
        'num_steps': 5 if net in ('lstm', 'cnn') else 1,
    }


def _configure_backend(backend):
    """Export the backend choice through environment variables.

    RLTRADER_BACKEND is always set. KERAS_BACKEND is set only for the
    'tensorflow' and 'plaidml' backends.
    """
    os.environ['RLTRADER_BACKEND'] = backend
    keras_backend = {
        'tensorflow': 'tensorflow',
        'plaidml': 'plaidml.keras.backend',
    }.get(backend)
    if keras_backend is not None:
        os.environ['KERAS_BACKEND'] = keras_backend


def _setup_logger(log_path, logger_name):
    """Configure ``logger_name`` to log INFO+ to stdout and DEBUG+ to ``log_path``.

    Any existing file at ``log_path`` is overwritten.

    Returns:
        The configured ``logging.Logger``.
    """
    logging.basicConfig(format='%(message)s')
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    # The handlers below are the only outputs; don't also emit via the root logger.
    logger.propagate = False
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.INFO)
    # mode='w' truncates a log left over from an earlier run with the same name.
    file_handler = logging.FileHandler(filename=log_path, mode='w', encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger


def _load_stock_data(stock_codes, start_date, end_date, ver, num_steps):
    """Load chart and training data for each stock code.

    Returns:
        list of ``(stock_code, chart_data, training_data)`` tuples, in input order.

    Raises:
        ValueError: if a stock yields fewer chart rows than ``num_steps``.
    """
    stock_data = []
    for stock_code in stock_codes:
        chart_data, training_data = data_manager.load_data(
            stock_code, start_date, end_date, ver=ver)
        # Explicit check rather than `assert`, which `python -O` strips.
        if len(chart_data) < num_steps:
            raise ValueError(
                f'{stock_code}: loaded {len(chart_data)} chart rows, '
                f'but num_steps={num_steps} requires at least that many')
        stock_data.append((stock_code, chart_data, training_data))
    return stock_data


def _build_learner(args, run, stock_data, output_path,
                   value_network_path, policy_network_path):
    """Construct the learner selected by ``args.rl_method``.

    Args:
        args: Parsed command-line namespace.
        run: Settings returned by ``_derive_run_config``.
        stock_data: Non-empty list of ``(stock_code, chart_data, training_data)``.
        output_path: Directory for this run's outputs.
        value_network_path: Model file path for the value network.
        policy_network_path: Model file path for the policy network.

    Returns:
        The learner instance. Every method except 'a3c' uses only the last
        entry of ``stock_data``.

    Raises:
        ValueError: for an unsupported ``args.rl_method``.
    """
    # Backend and logging must be configured before RLTrader's learner
    # modules are imported, so this import is deliberately deferred.
    from quantylab.rltrader.learners import (
        A2CLearner, A3CLearner, ActorCriticLearner, DQNLearner,
        PolicyGradientLearner, ReinforcementLearner)

    common_params = {
        'rl_method': args.rl_method, 'net': args.net,
        'num_steps': run['num_steps'], 'lr': args.lr,
        'balance': args.balance, 'num_epoches': run['num_epoches'],
        'discount_factor': args.discount_factor,
        'start_epsilon': run['start_epsilon'],
        'output_path': output_path, 'reuse_models': run['reuse_models'],
    }
    value_params = {'value_network_path': value_network_path}
    policy_params = {'policy_network_path': policy_network_path}

    if args.rl_method == 'a3c':
        # A3C trains on all stocks at once, one list entry per stock.
        num_stocks = len(stock_data)
        return A3CLearner(**{
            **common_params,
            'list_stock_code': [code for code, _, _ in stock_data],
            'list_chart_data': [chart for _, chart, _ in stock_data],
            'list_training_data': [train for _, _, train in stock_data],
            'list_min_trading_price': [MIN_TRADING_PRICE] * num_stocks,
            'list_max_trading_price': [MAX_TRADING_PRICE] * num_stocks,
            **value_params,
            **policy_params,
        })

    stock_code, chart_data, training_data = stock_data[-1]
    common_params.update({
        'stock_code': stock_code,
        'chart_data': chart_data,
        'training_data': training_data,
        'min_trading_price': MIN_TRADING_PRICE,
        'max_trading_price': MAX_TRADING_PRICE,
    })
    if args.rl_method == 'dqn':
        return DQNLearner(**{**common_params, **value_params})
    if args.rl_method == 'pg':
        return PolicyGradientLearner(**{**common_params, **policy_params})
    if args.rl_method == 'ac':
        return ActorCriticLearner(**{**common_params, **value_params, **policy_params})
    if args.rl_method == 'a2c':
        return A2CLearner(**{**common_params, **value_params, **policy_params})
    if args.rl_method == 'monkey':
        common_params.update({'net': 'monkey', 'num_epoches': 10, 'start_epsilon': 1})
        return ReinforcementLearner(**common_params)
    raise ValueError(f'Unsupported rl_method: {args.rl_method!r}')


def main(argv=None):
    """Run one RLTrader session (train / test / update / predict) from CLI args.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    run = _derive_run_config(args.mode, args.net)
    model_prefix = f'{args.name}_{args.rl_method}_{args.net}'
    output_name = f'{args.mode}_{model_prefix}'

    # Must happen before the RLTrader learner modules are imported.
    _configure_backend(args.backend)

    output_path = os.path.join(settings.BASE_DIR, 'output', output_name)
    os.makedirs(output_path, exist_ok=True)

    # Record the run parameters next to the outputs.
    params = json.dumps(vars(args))
    with open(os.path.join(output_path, 'params.json'), 'w', encoding='utf-8') as f:
        f.write(params)

    # Model file format: h5 for TensorFlow, pickle for PyTorch.
    models_dir = os.path.join(settings.BASE_DIR, 'models')
    value_network_path = os.path.join(models_dir, f'{model_prefix}_value.mdl')
    policy_network_path = os.path.join(models_dir, f'{model_prefix}_policy.mdl')

    logger = _setup_logger(os.path.join(output_path, f'{output_name}.log'),
                           settings.LOGGER_NAME)
    logger.info(params)

    stock_data = _load_stock_data(args.stock_code, args.start_date, args.end_date,
                                  args.ver, run['num_steps'])
    if args.rl_method != 'a3c' and len(stock_data) > 1:
        logger.warning('rl_method=%s uses only the last stock code (%s); '
                       '%d other code(s) ignored',
                       args.rl_method, stock_data[-1][0], len(stock_data) - 1)

    learner = _build_learner(args, run, stock_data, output_path,
                             value_network_path, policy_network_path)

    # rl_method='monkey' always runs with learning disabled.
    learning = run['learning'] and args.rl_method != 'monkey'
    if args.mode in ('train', 'test', 'update'):
        learner.run(learning=learning)
        if args.mode in ('train', 'update'):
            learner.save_models()
    elif args.mode == 'predict':
        learner.predict()


if __name__ == '__main__':
    main()