01_cartpole_dqn.py
#!/usr/bin/env python3
import gym
import ptan
import numpy as np
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim

GAMMA = 0.99            # discount factor for the one-step Bellman target
LEARNING_RATE = 0.01    # Adam learning rate
BATCH_SIZE = 8          # transitions sampled from the replay buffer per training step

EPSILON_START = 1.0     # epsilon-greedy exploration parameters
EPSILON_STOP = 0.02
EPSILON_STEPS = 5000

REPLAY_BUFFER = 50000   # replay buffer capacity (number of stored transitions)
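
# Epsilon schedule used in the training loop below, factored out here only as an
# illustrative (hypothetical) helper: epsilon decays linearly from EPSILON_START by
# 1 / EPSILON_STEPS per environment step and is clamped at EPSILON_STOP, so the
# 0.02 floor is reached after 4900 steps. The loop itself keeps the inline expression.
def _epsilon_at(step_idx):
    return max(EPSILON_STOP, EPSILON_START - step_idx / EPSILON_STEPS)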


class DQN(nn.Module):
    def __init__(self, input_size, n_actions):
        super(DQN, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

    def forward(self, x):
        return self.net(x)
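

# Shape sketch for the Q-network above (hypothetical helper, names are illustrative,
# and it is never called by the training loop): a batch of CartPole observations of
# shape [batch, 4] maps to a batch of Q-values of shape [batch, n_actions].
def _dqn_shape_example():
    demo_net = DQN(4, 2)              # CartPole-v0: 4 observation values, 2 actions
    batch_obs = torch.zeros(3, 4)     # three dummy observations
    q_values = demo_net(batch_obs)    # tensor of shape [3, 2]
    return q_values.shape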


def calc_target(net, local_reward, next_state):
    if next_state is None:
        return local_reward
    state_v = torch.tensor([next_state], dtype=torch.float32)
    next_q_v = net(state_v)
    best_q = next_q_v.max(dim=1)[0].item()
    return local_reward + GAMMA * best_q
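

# Usage sketch for calc_target (hypothetical helper, for illustration only, never
# called by the training loop): it returns the one-step Bellman target, i.e. the bare
# reward on terminal transitions (next_state is None) and
# reward + GAMMA * max_a Q(next_state, a) otherwise.
def _calc_target_example():
    demo_net = DQN(4, 2)
    terminal_target = calc_target(demo_net, 1.0, None)                       # == 1.0
    boot_target = calc_target(demo_net, 1.0, np.zeros(4, dtype=np.float32))  # 1.0 + 0.99 * max Q
    return terminal_target, boot_target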


if __name__ == "__main__":
    env = gym.make("CartPole-v0")
    writer = SummaryWriter(comment="-cartpole-dqn")

    net = DQN(env.observation_space.shape[0], env.action_space.n)
    print(net)

    selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=EPSILON_START)
    agent = ptan.agent.DQNAgent(net, selector, preprocessor=ptan.agent.float32_preprocessor)
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA)
    replay_buffer = ptan.experience.ExperienceReplayBuffer(exp_source, REPLAY_BUFFER)

    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    mse_loss = nn.MSELoss()

    total_rewards = []
    step_idx = 0
    done_episodes = 0

    while True:
        step_idx += 1
        # linearly anneal epsilon, then hold it at EPSILON_STOP
        selector.epsilon = max(EPSILON_STOP, EPSILON_START - step_idx / EPSILON_STEPS)
        replay_buffer.populate(1)

        if len(replay_buffer) < BATCH_SIZE:
            continue

        # sample batch
        batch = replay_buffer.sample(BATCH_SIZE)
        batch_states = [exp.state for exp in batch]
        batch_actions = [exp.action for exp in batch]
        batch_targets = [calc_target(net, exp.reward, exp.last_state)
                         for exp in batch]

        # train: regress the Q-values of the taken actions towards their Bellman targets
        optimizer.zero_grad()
        states_v = torch.FloatTensor(batch_states)
        net_q_v = net(states_v)
        target_q = net_q_v.data.numpy().copy()
        target_q[range(BATCH_SIZE), batch_actions] = batch_targets
        target_q_v = torch.tensor(target_q)
        loss_v = mse_loss(net_q_v, target_q_v)
        loss_v.backward()
        optimizer.step()

        # handle new rewards reported by finished episodes
        new_rewards = exp_source.pop_total_rewards()
        if new_rewards:
            done_episodes += 1
            reward = new_rewards[0]
            total_rewards.append(reward)
            mean_rewards = float(np.mean(total_rewards[-100:]))
            print("%d: reward: %6.2f, mean_100: %6.2f, epsilon: %.2f, episodes: %d" % (
                step_idx, reward, mean_rewards, selector.epsilon, done_episodes))
            writer.add_scalar("reward", reward, step_idx)
            writer.add_scalar("reward_100", mean_rewards, step_idx)
            writer.add_scalar("epsilon", selector.epsilon, step_idx)
            writer.add_scalar("episodes", done_episodes, step_idx)
            if mean_rewards > 195:
                # CartPole-v0 counts as solved at a mean reward of 195 over 100 episodes
                print("Solved in %d steps and %d episodes!" % (step_idx, done_episodes))
                break
    writer.close()
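
# Usage sketch (assumed workflow): run the script directly and inspect the logged
# metrics with TensorBoard, e.g.
#
#   python 01_cartpole_dqn.py
#   tensorboard --logdir runs
#
# tensorboardX's SummaryWriter writes its event files under ./runs by default, with
# the "-cartpole-dqn" comment appended to the run directory name.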