cartpole_main.py
import math
import torch
import matplotlib
from matplotlib import pyplot as plt
from action_selection import epsilon_greedy
from deep_q_network import deep_q_network
from replay_buffer import ExperienceReplayBuffer
from training_functions import soft_update, optimize
import torch.optim as optim
from visualization import plot_rewards
import gymnasium as gym
import time
is_ipython = "inline" in matplotlib.get_backend()
plt.ion()
plt.style.use("Solarize_Light2")
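# plt.ion() enables interactive plotting so plot_rewards can refresh the figure
# during training; is_ipython records whether an inline (notebook) backend is active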
# set up environment and parameters
env = gym.make("CartPole-v1")
seed = 0
obs, _ = env.reset(seed=seed)
state_space = len(obs)
n_actions = env.action_space.n
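# CartPole-v1 has a 4-dimensional observation (cart position, cart velocity,
# pole angle, pole angular velocity) and 2 discrete actions (push left / push right)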
# agent parameters
initial_epsilon = 0.95
epsilon_decay = 1000
final_epsilon = 0.05
discount_factor = 0.99
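# epsilon is annealed from initial_epsilon toward final_epsilon during training;
# note that epsilon_decay is not referenced again in this file (the training
# loop below instead multiplies epsilon by 0.9996 each environment step)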
# training parameters
tau = 0.005
batch_size = 128
learning_rate = 0.0001
buffer_capacity = 10_000
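# tau sets the soft-update rate for the target network, batch_size is the number
# of transitions sampled per optimization step, and buffer_capacity caps the
# experience replay buffer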
gpu_available = torch.backends.mps.is_available()
if gpu_available:
    device = torch.device("mps")
    n_episodes = 700
else:
    device = torch.device("cpu")
    n_episodes = 30
# somehow, training is faster on cpu, so we use this for now
device = torch.device("cpu")
# initialize policy and target nets
policy_net = deep_q_network(state_space, n_actions).to(device)
target_net = deep_q_network(state_space, n_actions).to(device)
optimizer = optim.AdamW(policy_net.parameters(), lr=learning_rate, amsgrad=True)
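# AdamW with the AMSGrad variant optimizes only the policy network; the target
# network is never trained directly and only tracks the policy net via soft updates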
# training loop
epsilon = initial_epsilon
buffer = ExperienceReplayBuffer(buffer_capacity)
episode_rewards = []
epsilon_values = [epsilon]
t0 = time.time()
steps = 0
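# per step: select an epsilon-greedy action, step the environment, store the
# transition in the replay buffer, decay epsilon, run one optimization step on
# the policy net, then soft-update the target net toward the policy net
# (note: steps is initialized above but never incremented in this file)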
for episode in range(n_episodes):
    state, _ = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    done = False
    episode_reward = 0
    while not done:
        # choose action
        action = epsilon_greedy(policy_net, env, state, epsilon, device)
        # observe environment
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        episode_reward += reward
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(
                next_state, dtype=torch.float32, device=device
            ).unsqueeze(0)
        buffer.add(state, action, reward, next_state)
        state = next_state
        epsilon = max(epsilon * 0.9996, final_epsilon)
        optimize(
            policy_net,
            target_net,
            buffer,
            batch_size,
            device,
            discount_factor,
            optimizer,
        )
        soft_update(policy_net, target_net, tau)
    epsilon_values.append(epsilon)
    episode_rewards.append(episode_reward)
    plot_rewards(episode_rewards, epsilon_values, is_ipython)
env.close()
print(f"Training complete. Time elapsed: {time.time() - t0}")
plot_rewards(episode_rewards, epsilon_values, is_ipython, show_results=True)
plt.ioff()
plt.show()
# save model
torch.save(
    policy_net.state_dict(),
    "/Users/paultalma/Programming/Python/reinforcement-learning/classic_control/cartpole/saved_models/cart_pole_model.pth",
)
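# To evaluate the trained agent later, the saved weights can be reloaded, e.g.
# (a minimal sketch, using the same path as in torch.save above):
#   eval_net = deep_q_network(state_space, n_actions)
#   eval_net.load_state_dict(torch.load("<path used in torch.save above>"))
#   eval_net.eval()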