Experiment_policy.py
import numpy as np
import time
from DQN_combined import DQNAgent
import gym
from Learning_Curve import LearningCurvePlot, smooth_curve
def average_over_repetitions(state_n, action_n, hidden_n, learning_rate, gamma, memory_size, batch_size,
                             target_update_interval, episodes, policy, epsilon, eval_interval, temperature,
                             use_memory, use_target_network, n_repetitions, smoothing_window):
    """Train a DQN agent n_repetitions times and return the smoothed, averaged learning curve."""
    returns_over_repetitions = []
    now = time.time()
    env = gym.make('CartPole-v1')
    print("Running Policy Experiment...")
    for rep in range(n_repetitions):  # Loop over repetitions
        print("DQN {} repetition {}/{}".format(policy, rep + 1, n_repetitions))
        agent = DQNAgent(state_n, action_n, hidden_n, learning_rate, gamma, memory_size, batch_size,
                         target_update_interval, episodes, policy, epsilon, eval_interval, temperature,
                         use_memory, use_target_network)
        episode_rewards = agent.fit(env)
        returns_over_repetitions.append(episode_rewards)
    print("This experiment took {} minutes".format((time.time() - now) / 60))
    # Average the per-episode returns across repetitions, then smooth the curve for plotting
    learning_curve = np.mean(np.array(returns_over_repetitions), axis=0)
    learning_curve = smooth_curve(learning_curve, smoothing_window)
    return learning_curve
def experiment():
    # Experiment settings shared by both policies
    n_repetitions = 20
    smoothing_window = 9
    state_n = 4    # CartPole-v1 observation dimension
    action_n = 2   # CartPole-v1 action dimension
    target_update_interval = 100
    memory_size = 10000
    episodes = 1000
    eval_interval = 100
    use_memory = True
    use_target_network = True
    # Default values for epsilon and temperature
    epsilon = 0.1
    temperature = 0.2

    # Tuned hyperparameters for e-greedy
    policy = 'e-greedy'
    learning_rate = 0.001
    hidden_n = 256
    batch_size = 64
    epsilon = 0.1
    gamma = 0.95

Plot = LearningCurvePlot(title='Policy Analysis')
# DQN e-greedy
learning_curve = average_over_repetitions(state_n, action_n, hidden_n, learning_rate, gamma, memory_size, batch_size,
target_update_interval, episodes, policy, epsilon, eval_interval, temperature,
use_memory, use_target_network,
n_repetitions, smoothing_window)
Plot.add_curve(learning_curve, label="DQN e-greedy")
    # Tuned hyperparameters for softmax
policy = 'softmax'
learning_rate = 0.01
hidden_n = 256
batch_size = 32
temperature = 0.01
gamma = 1
# DQN softmax
learning_curve = average_over_repetitions(state_n, action_n, hidden_n, learning_rate, gamma, memory_size, batch_size,
target_update_interval, episodes, policy, epsilon, eval_interval, temperature,
use_memory, use_target_network,
n_repetitions, smoothing_window)
Plot.add_curve(learning_curve, label="DQN softmax")
Plot.save('experiment_policy.png')
if __name__ == '__main__':
experiment()