forked from AI4Finance-Foundation/FinRL_Podracer
train.py
from elegant_finrl.run import *  # provides Arguments and train_and_evaluate_mp
from elegant_finrl.agent import AgentPPO, AgentDDPG
from elegant_finrl.env import StockTradingEnv
import elegant_finrl.config as config
import yfinance as yf
from stockstats import StockDataFrame as Sdf
import numpy as np  # required below for initial_stocks
# Agent
args = Arguments(if_on_policy=True)  # PPO is on-policy
args.agent = AgentPPO()  # alternatives: AgentSAC(), AgentTD3(), AgentDDPG()
args.agent.if_use_gae = True  # use generalized advantage estimation
args.agent.lambda_entropy = 0.04  # entropy bonus coefficient
# Environment
tickers = config.NAS_100_TICKER
tech_indicator_list = [
'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
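# Note: stockstats computes these columns lazily from OHLCV data, e.g.
# Sdf.retype(ohlcv_df)['macd'] derives MACD on first access; this is
# presumably why stockstats (and yfinance, for raw data) are imported above.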
gamma = 0.99  # discount factor
max_stock = 1e2  # cap on shares traded per action
initial_capital = 1e6
initial_stocks = np.zeros(len(tickers), dtype=np.float32)  # start with no holdings
buy_cost_pct = 1e-3  # 0.1% transaction cost
sell_cost_pct = 1e-3
start_date = '2009-01-01'
start_eval_date = '2019-01-01'
end_eval_date = '2021-01-01'
args.env = StockTradingEnv('./', gamma, max_stock, initial_capital, buy_cost_pct,
sell_cost_pct, start_date, start_eval_date,
end_eval_date, tickers, tech_indicator_list,
initial_stocks, reward_scaling=2 ** -14, if_eval=False)
args.env_eval = StockTradingEnv('./', gamma, max_stock, initial_capital, buy_cost_pct,
sell_cost_pct, start_date, start_eval_date,
end_eval_date, tickers, tech_indicator_list,
initial_stocks, reward_scaling=2 ** -14, if_eval=True)
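# Same arguments, split by date: training covers start_date to start_eval_date
# (2009-2019); if_eval=True switches the env to the out-of-sample window
# start_eval_date to end_eval_date (2019-2021).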
args.env.target_reward = 2  # early-stop threshold (inactive here: if_allow_break is False below)
args.env_eval.target_reward = 2
# Hyperparameters
args.gamma = gamma
args.break_step = int(2e5)  # total environment steps before training stops
args.net_dim = 2 ** 9  # network width
args.max_step = args.env.max_step
args.max_memo = args.max_step * 4  # trajectory buffer capacity
args.batch_size = 2 ** 10
args.repeat_times = 2 ** 3  # update epochs per collected batch
args.eval_gap = 2 ** 4  # interval between evaluations
args.eval_times1 = 2 ** 3  # evaluation episodes per check
args.eval_times2 = 2 ** 5  # extra episodes to confirm a new best score
args.if_allow_break = False  # do not stop early at target_reward
args.rollout_num = 2  # the number of rollout workers (larger is not always faster)
train_and_evaluate_mp(args)  # multiprocess training; runs until break_step since early stopping is disabled
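
A quick way to sanity-check the result is to roll the trained policy through the out-of-sample environment. The sketch below is illustrative only: it assumes a gym-style env (reset/step, which StockTradingEnv follows) and a trained actor network mapping states to actions; the exact save/load API for the actor varies across ElegantRL versions, so loading is left abstract.

# Hypothetical post-training evaluation sketch (not part of train.py).
# `actor` is assumed to be the trained policy network saved under args.cwd.
import torch

def evaluate_episode(env, actor, device='cpu'):
    state = env.reset()
    episode_return = 0.0
    for _ in range(env.max_step):
        s = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            action = actor(s).squeeze(0).cpu().numpy()
        state, reward, done, _ = env.step(action)
        episode_return += reward
        if done:
            break
    return episode_return

# e.g. evaluate_episode(args.env_eval, actor)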