__main__.py
# coding=utf-8
"""
Invoke REINFORCE agent using
python -m BasicPolicyGradient [--help] [--train] [--render_training] [--discounted] [--play_for] [--test_run]
Note on TensorBoard usage:
Start TensorBoard in terminal:
cd DRLimplementation
tensorboard --logdir=BasicPolicyGradient/graph
In browser, go to:
http://0.0.0.0:6006/
"""
import argparse
import tensorflow as tf
from BasicPolicyGradient.REINFORCEagent import REINFORCEagent
from blocAndTools.buildingbloc import ExperimentSpec
""" Note: About Gamma value (aka the discout factor)
| Big difference between 0.9 and 0.999.
| Also you need to take into account the experiment average number of step per episode
|
| Example of 'discounted return' over 100 timestep:
| 0.9^100 --> 0.000026 vs 0.99^100 --> 0.366003 vs 0.999^100 --> 0.904792
|
| Meaning a agent with Gamma=0.9 is short-sighted and one with Gamma=0.9999 is farsighted or clairvoyant
"""
cartpole_hparam = {
    'paramameter_set_name': 'Basic',
    'algo_name': 'REINFORCE',
    'comment': None,
    'prefered_environment': 'CartPole-v0',
    'expected_reward_goal': 200,
    'batch_size_in_ts': 5000,
    'max_epoch': 40,
    'discounted_reward_to_go': True,
    'discout_factor': 0.999,
    'learning_rate': 1e-2,
    'theta_nn_h_layer_topo': (62,),
    'random_seed': 82,
    'theta_hidden_layers_activation': tf.nn.tanh,  # alternative: tf.nn.relu
    'theta_output_layers_activation': None,
    'render_env_every_What_epoch': 100,
    'print_metric_every_what_epoch': 2,
    'isTestRun': False,
    'show_plot': False,
}
test_hparam = {
    'paramameter_set_name': 'Basic',
    'algo_name': 'REINFORCE',
    'comment': 'TestSpec',
    'prefered_environment': 'CartPole-v0',
    'expected_reward_goal': 200,
    'batch_size_in_ts': 1000,
    'max_epoch': 5,
    'discounted_reward_to_go': True,
    'discout_factor': 0.999,
    'learning_rate': 1e-2,
    'theta_nn_h_layer_topo': (2,),
    'random_seed': 82,
    'theta_hidden_layers_activation': tf.nn.tanh,
    'theta_output_layers_activation': None,
    'render_env_every_What_epoch': 5,
    'print_metric_every_what_epoch': 2,
    'isTestRun': True,
    'show_plot': False,
}
parser = argparse.ArgumentParser(
    description=(
        "=============================================================================\n"
        ":: Command line options for the REINFORCE Agent.\n\n"
        "   By default, the agent plays using a previously trained computation graph.\n"
        "   You can execute training by using the argument: --train "),
    epilog="=============================================================================\n")
# parser.add_argument('--env', type=str, default='CartPole-v0')
parser.add_argument('--train', action='store_true', help='Execute training of the REINFORCE agent')
parser.add_argument('-r', '--render_training', action='store_true',
                    help='(Training option) Watch the agent execute trajectories while it is on training duty')
# Note: argparse's type=bool would treat any non-empty string (including "False") as True,
# so parse the flag value explicitly.
parser.add_argument('-d', '--discounted', default=None, type=lambda v: v.lower() in ('true', '1', 'yes'),
                    help='(Training option) Force training execution with discounted reward-to-go')
parser.add_argument('-p', '--play_for', type=int, default=20,
                    help='(Playing option) Max number of playing trajectories, default=20')
parser.add_argument('--test_run', action='store_true')
args = parser.parse_args()
exp_spec = ExperimentSpec()

if args.train:
    # Configure the experiment hyper-parameters
    if args.test_run:
        exp_spec.set_experiment_spec(test_hparam)
    else:
        exp_spec.set_experiment_spec(cartpole_hparam)

    if args.discounted is not None:
        exp_spec.set_experiment_spec({'discounted_reward_to_go': args.discounted})

    reinforce_agent = REINFORCEagent(exp_spec)
    reinforce_agent.train(render_env=args.render_training)
else:
    exp_spec.set_experiment_spec(cartpole_hparam)
    if args.test_run:
        exp_spec.set_experiment_spec({'isTestRun': True})
    reinforce_agent = REINFORCEagent(exp_spec)
    reinforce_agent.play(run_name='REINFORCE_agent-39', max_trajectories=args.play_for)

exit(0)