'''
File name: a2c_train.py
Author: Jayson Ng
Email: [email protected]
Date created: 15/7/2021
Python Version: 3.7
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import gym
import matplotlib.pyplot as plt
import copy
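
# ---------------------------------------------------------------------------
# NOTE: A2CNetwork and TDAgent are referenced below but their definitions are
# not included in this file. The classes that follow are a minimal
# illustrative sketch (an assumption, not the author's exact implementation)
# so the script can run stand-alone: a shared-body actor-critic network and
# an agent that performs a one-step advantage actor-critic update on each
# stored transition. The constructor arguments mirror the call in the
# training loop; `batch_size` and `target_update_intv` are accepted for
# interface compatibility but are not used by this sketch.
# ---------------------------------------------------------------------------
class A2CNetwork(nn.Module):
    """Shared hidden layer with separate policy (actor) and value (critic) heads."""

    def __init__(self, obs_size, n_actions, hid_size=128):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_size, hid_size), nn.ReLU())
        self.policy_head = nn.Linear(hid_size, n_actions)  # action logits
        self.value_head = nn.Linear(hid_size, 1)           # state value estimate

    def forward(self, x):
        h = self.body(x)
        return self.policy_head(h), self.value_head(h)


class TDAgent:
    """Actor-critic agent that learns online from the most recent transition."""

    def __init__(self, net, capacity, n_actions, batch_size, discount_factor,
                 lr, target_update_intv):
        self.net = net
        self.capacity = capacity
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = discount_factor
        self.target_update_intv = target_update_intv
        self.optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        self.buffer = []
        self.device = next(net.parameters()).device

    def select_action(self, s):
        """Sample an action from the current policy; return it with its probability."""
        state = torch.as_tensor(s, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            logits, _ = self.net(state)
        probs = F.softmax(logits, dim=-1)
        a = torch.distributions.Categorical(probs).sample().item()
        return a, probs[0, a].item()

    def store_transition(self, s, a, new_s, r, a_prob):
        # Keep a bounded transition buffer; a_prob is stored for compatibility.
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append((s, a, new_s, r, a_prob))

    def learn(self):
        """One-step advantage actor-critic update on the latest stored transition."""
        if not self.buffer:
            return None
        s, a, new_s, r, _ = self.buffer[-1]
        state = torch.as_tensor(s, dtype=torch.float32, device=self.device).unsqueeze(0)
        next_state = torch.as_tensor(new_s, dtype=torch.float32, device=self.device).unsqueeze(0)
        logits, value = self.net(state)
        with torch.no_grad():
            _, next_value = self.net(next_state)
        td_target = r + self.gamma * next_value          # bootstrapped TD target
        advantage = (td_target - value).detach().squeeze()
        log_prob = F.log_softmax(logits, dim=-1)[0, a]
        actor_loss = -(log_prob * advantage)             # policy-gradient term
        critic_loss = F.mse_loss(value, td_target)       # value regression term
        loss = actor_loss + critic_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
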
if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    env = env.unwrapped

    HIDDEN_LAYER = 128              # NN hidden layer size

    # Hyper-parameters
    log_intv = 10                   # episodes between progress prints
    capacity = 10000                # transition buffer capacity
    target_update_intv = 500        # in terms of iterations
    max_episodes = 5000
    max_steps = 500
    lr = 0.008
    discount_factor = 0.99
    batch_size = 256
    goal = 200                      # benchmark of the CartPole-v0 problem
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    net = A2CNetwork(env.observation_space.high.shape[0], env.action_space.n,
                     hid_size=HIDDEN_LAYER).to(device)
    agent = TDAgent(net, capacity, env.action_space.n, batch_size,
                    discount_factor, lr, target_update_intv)

    losses = []
    reward_hist = []
    avg_reward_hist = []
    avg_reward = 8                  # seed for the exponential moving average
    best_avg_reward = 0

    for episode_i in tqdm(range(max_episodes)):
        # CartPole-v0 counts as solved when the average reward over the last
        # 100 episodes reaches the goal.
        if len(reward_hist) >= 100 and np.mean(reward_hist[-100:]) >= goal:
            print(f'Solved! Average reward reached {goal} over the past 100 runs')
            break

        s = env.reset()
        ep_reward = 0
        ep_loss = 0
        for si in range(max_steps):
            a, a_prob = agent.select_action(s)
            new_s, r, done, info = env.step(a)
            if done:
                r = -1              # penalize the terminal step
            agent.store_transition(s, a, new_s, r, a_prob)
            ep_reward += 1

            loss = agent.learn()
            if loss is not None:
                ep_loss += loss

            if done:
                reward_hist.append(ep_reward)
                break
            s = new_s

        losses.append(ep_loss / (si + 1))
        avg_reward = int(0.95 * avg_reward + 0.05 * ep_reward)
        avg_reward_hist.append(avg_reward)
        if episode_i % log_intv == 0:
            print(f'Episode {episode_i} | Reward: {ep_reward} | Avg Reward: {avg_reward} | Loss: {loss}')