
Commit 43440cd

Author: xuhaoran4
Commit message: add
1 parent 8cca5ad commit 43440cd

File tree

4 files changed: +321 -234 lines changed


BCQ_L.py (+203 lines)
@@ -0,0 +1,203 @@
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Perturbation network: outputs a small correction (scaled by phi) to an action
# sampled from the VAE, keeping the policy close to the behavior distribution.
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, hidden_unit=256, phi=0.05):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.l2 = nn.Linear(hidden_unit, hidden_unit)
        self.l3 = nn.Linear(hidden_unit, action_dim)

        self.max_action = max_action
        self.phi = phi

    def forward(self, state, action):
        a = F.relu(self.l1(torch.cat([state, action], 1)))
        a = F.relu(self.l2(a))
        a = self.phi * self.max_action * torch.tanh(self.l3(a))
        return (a + action).clamp(-self.max_action, self.max_action)


# Twin Q-networks; q1() exposes the first head for the actor (DPG) update.
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_unit=256):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.l2 = nn.Linear(hidden_unit, hidden_unit)
        self.l3 = nn.Linear(hidden_unit, 1)

        self.l4 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.l5 = nn.Linear(hidden_unit, hidden_unit)
        self.l6 = nn.Linear(hidden_unit, 1)

    def forward(self, state, action):
        q1 = F.relu(self.l1(torch.cat([state, action], 1)))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(torch.cat([state, action], 1)))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def q1(self, state, action):
        q1 = F.relu(self.l1(torch.cat([state, action], 1)))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1


# Vanilla Variational Auto-Encoder
class VAE(nn.Module):
    def __init__(self, state_dim, action_dim, latent_dim, max_action, hidden_unit=256):
        super(VAE, self).__init__()
        self.e1 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.e2 = nn.Linear(hidden_unit, hidden_unit)

        self.mean = nn.Linear(hidden_unit, latent_dim)
        self.log_std = nn.Linear(hidden_unit, latent_dim)

        self.d1 = nn.Linear(state_dim + latent_dim, hidden_unit)
        self.d2 = nn.Linear(hidden_unit, hidden_unit)
        self.d3 = nn.Linear(hidden_unit, action_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim

    def forward(self, state, action):
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, mean, std

    def decode(self, state, z=None):
        # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            z = torch.randn((state.shape[0], self.latent_dim)).to(device).clamp(-0.5, 0.5)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))


class BCQ(object):
    def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005, lmbda=0.75, phi=0.05):
        latent_dim = action_dim * 2

        self.actor = Actor(state_dim, action_dim, max_action, phi=phi).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3)

        self.cost_critic = Critic(state_dim, action_dim).to(device)
        self.cost_critic_target = copy.deepcopy(self.cost_critic)
        # Optimize the cost critic's own parameters (not the reward critic's)
        self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic.parameters(), lr=1e-3)

        self.vae = VAE(state_dim, action_dim, latent_dim, max_action).to(device)
        self.vae_optimizer = torch.optim.Adam(self.vae.parameters())

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.lmbda = lmbda

    def select_action(self, state):
        with torch.no_grad():
            # Sample 100 candidate actions from the VAE, perturb them with the actor,
            # and return the one with the highest Q1 value
            state = torch.FloatTensor(state.reshape(1, -1)).repeat(100, 1).to(device)
            action = self.actor(state, self.vae.decode(state))
            q1 = self.critic.q1(state, action)
            ind = q1.argmax(0)
        return action[ind].cpu().data.numpy().flatten()

    def train(self, replay_buffer, batch_size=100):
        # Sample replay buffer / batch
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        # Variational Auto-Encoder Training: fit the generative model to the dataset actions
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        # Reward Critic Training
        with torch.no_grad():
            # Duplicate next state 10 times
            next_state = torch.repeat_interleave(next_state, 10, 0)

            # Compute value of perturbed actions sampled from the VAE
            target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))

            # Soft Clipped Double Q-learning
            target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
            # Take max over each action sampled from the VAE
            target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

            target_Q = reward + not_done * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Cost Critic Training (uses the cost critic networks and optimizer rather than
        # the reward critic's; next_state was already duplicated 10 times above)
        with torch.no_grad():
            # Compute cost value of perturbed actions sampled from the VAE
            target_C1, target_C2 = self.cost_critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))

            # Soft Clipped Double Q-learning
            target_C = self.lmbda * torch.min(target_C1, target_C2) + (1. - self.lmbda) * torch.max(target_C1, target_C2)
            # Take max over each action sampled from the VAE
            target_C = target_C.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

            # NOTE: the sampled batch carries no separate cost signal, so `reward` is used
            # as the immediate term here; a cost term from the replay buffer would replace it
            target_C = reward + not_done * self.discount * target_C

        current_C1, current_C2 = self.cost_critic(state, action)
        cost_critic_loss = F.mse_loss(current_C1, target_C) + F.mse_loss(current_C2, target_C)

        self.cost_critic_optimizer.zero_grad()
        cost_critic_loss.backward()
        self.cost_critic_optimizer.step()

        # Perturbation Model / Action Training
        sampled_actions = self.vae.decode(state)
        perturbed_actions = self.actor(state, sampled_actions)

        # Update through DPG
        actor_loss = -self.critic.q1(state, perturbed_actions).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update Target Networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.cost_critic.parameters(), self.cost_critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
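
For context, a minimal sketch of how the class added in this commit might be driven is shown below. The ReplayBuffer stub, the random transition data, the dimensions, and the loop counts are illustrative assumptions and are not part of the commit; a real offline-RL setup would load a pre-collected dataset instead.

# Hypothetical usage sketch (not part of this commit).
import numpy as np
import torch

from BCQ_L import BCQ, device


class ReplayBuffer(object):
    # Minimal fixed-size buffer filled with random transitions for illustration only;
    # sample() returns (state, action, next_state, reward, not_done) as BCQ.train expects.
    def __init__(self, state_dim, action_dim, size):
        self.state = np.random.randn(size, state_dim).astype(np.float32)
        self.action = np.random.uniform(-1, 1, (size, action_dim)).astype(np.float32)
        self.next_state = np.random.randn(size, state_dim).astype(np.float32)
        self.reward = np.random.randn(size, 1).astype(np.float32)
        self.not_done = np.ones((size, 1), dtype=np.float32)
        self.size = size

    def sample(self, batch_size):
        ind = np.random.randint(0, self.size, size=batch_size)
        to_tensor = lambda x: torch.FloatTensor(x[ind]).to(device)
        return (to_tensor(self.state), to_tensor(self.action),
                to_tensor(self.next_state), to_tensor(self.reward),
                to_tensor(self.not_done))


state_dim, action_dim, max_action = 3, 1, 1.0
policy = BCQ(state_dim, action_dim, max_action)
buffer = ReplayBuffer(state_dim, action_dim, size=10000)

# Offline training: repeatedly sample batches from the fixed buffer.
for _ in range(1000):
    policy.train(buffer, batch_size=100)

# Greedy action selection for a single state.
action = policy.select_action(np.random.randn(state_dim))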
