import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


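# Perturbation network: given a state and a candidate action (sampled from the
# VAE), it outputs a small correction bounded by phi * max_action, so the final
# action stays close to the data distribution.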
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action, hidden_unit=256, phi=0.05):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.l2 = nn.Linear(hidden_unit, hidden_unit)
        self.l3 = nn.Linear(hidden_unit, action_dim)

        self.max_action = max_action
        self.phi = phi

    def forward(self, state, action):
        a = F.relu(self.l1(torch.cat([state, action], 1)))
        a = F.relu(self.l2(a))
        a = self.phi * self.max_action * torch.tanh(self.l3(a))
        return (a + action).clamp(-self.max_action, self.max_action)


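# Twin Q-networks for Clipped Double Q-learning; q1() exposes only the first
# head, which is used for the perturbation (actor) update and action selection.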
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_unit=256):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.l2 = nn.Linear(hidden_unit, hidden_unit)
        self.l3 = nn.Linear(hidden_unit, 1)

        self.l4 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.l5 = nn.Linear(hidden_unit, hidden_unit)
        self.l6 = nn.Linear(hidden_unit, 1)

    def forward(self, state, action):
        q1 = F.relu(self.l1(torch.cat([state, action], 1)))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)

        q2 = F.relu(self.l4(torch.cat([state, action], 1)))
        q2 = F.relu(self.l5(q2))
        q2 = self.l6(q2)
        return q1, q2

    def q1(self, state, action):
        q1 = F.relu(self.l1(torch.cat([state, action], 1)))
        q1 = F.relu(self.l2(q1))
        q1 = self.l3(q1)
        return q1


# Vanilla Variational Auto-Encoder
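# Models the behaviour policy's state-conditioned action distribution; decode()
# draws actions similar to those in the batch, with the latent clipped to
# [-0.5, 0.5] when sampling so generated actions stay near the data.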
class VAE(nn.Module):
    def __init__(self, state_dim, action_dim, latent_dim, max_action, hidden_unit=256):
        super(VAE, self).__init__()
        self.e1 = nn.Linear(state_dim + action_dim, hidden_unit)
        self.e2 = nn.Linear(hidden_unit, hidden_unit)

        self.mean = nn.Linear(hidden_unit, latent_dim)
        self.log_std = nn.Linear(hidden_unit, latent_dim)

        self.d1 = nn.Linear(state_dim + latent_dim, hidden_unit)
        self.d2 = nn.Linear(hidden_unit, hidden_unit)
        self.d3 = nn.Linear(hidden_unit, action_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim

    def forward(self, state, action):
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, mean, std

    def decode(self, state, z=None):
        # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            z = torch.randn((state.shape[0], self.latent_dim)).to(device).clamp(-0.5, 0.5)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))


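# Batch-Constrained deep Q-learning (BCQ) agent with an additional cost critic.
# The VAE proposes in-distribution candidate actions, the Actor perturbs them,
# and the twin critics evaluate them for the value targets and action selection.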
class BCQ(object):
    def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005, lmbda=0.75, phi=0.05):
        latent_dim = action_dim * 2

        self.actor = Actor(state_dim, action_dim, max_action, phi=phi).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3)

        self.cost_critic = Critic(state_dim, action_dim).to(device)
        self.cost_critic_target = copy.deepcopy(self.cost_critic)
        self.cost_critic_optimizer = torch.optim.Adam(self.cost_critic.parameters(), lr=1e-3)

        self.vae = VAE(state_dim, action_dim, latent_dim, max_action).to(device)
        self.vae_optimizer = torch.optim.Adam(self.vae.parameters())

        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.lmbda = lmbda

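    # Duplicates the state 100 times, decodes a candidate action from the VAE
    # for each copy, perturbs them with the actor, and returns the candidate
    # with the highest Q1 value.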
    def select_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).repeat(100, 1).to(device)
            action = self.actor(state, self.vae.decode(state))
            q1 = self.critic.q1(state, action)
            ind = q1.argmax(0)
        return action[ind].cpu().data.numpy().flatten()

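    # One gradient step per component: fit the VAE to the batch actions, update
    # the reward and cost critics towards their soft-clipped targets, update the
    # perturbation actor by DPG, then Polyak-average the target networks.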
    def train(self, replay_buffer, batch_size=100):
        # Sample replay buffer / batch
        # NOTE: assumes the buffer also stores a per-step cost signal for the
        # cost critic; adapt the unpacking below to your buffer's interface.
        state, action, next_state, reward, cost, not_done = replay_buffer.sample(batch_size)

        # Variational Auto-Encoder Training
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        # Reward Critic Training
        with torch.no_grad():
            # Duplicate next state 10 times
            next_state = torch.repeat_interleave(next_state, 10, 0)

            # Compute value of perturbed actions sampled from the VAE
            target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))

            # Soft Clipped Double Q-learning
            target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
            # Take max over each action sampled from the VAE
            target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

            target_Q = reward + not_done * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Cost Critic Training (mirrors the reward critic, but regresses the cost signal)
        with torch.no_grad():
            # next_state was already duplicated 10 times above

            # Compute cost of perturbed actions sampled from the VAE
            target_Qc1, target_Qc2 = self.cost_critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))

            # Soft Clipped Double Q-learning
            target_Qc = self.lmbda * torch.min(target_Qc1, target_Qc2) + (1. - self.lmbda) * torch.max(target_Qc1, target_Qc2)
            # Take max over each action sampled from the VAE
            target_Qc = target_Qc.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

            target_Qc = cost + not_done * self.discount * target_Qc

        current_Qc1, current_Qc2 = self.cost_critic(state, action)
        cost_critic_loss = F.mse_loss(current_Qc1, target_Qc) + F.mse_loss(current_Qc2, target_Qc)

        self.cost_critic_optimizer.zero_grad()
        cost_critic_loss.backward()
        self.cost_critic_optimizer.step()

187
+ # Pertubation Model / Action Training
188
+ sampled_actions = self .vae .decode (state )
189
+ perturbed_actions = self .actor (state , sampled_actions )
190
+
191
+ # Update through DPG
192
+ actor_loss = - self .critic .q1 (state , perturbed_actions ).mean ()
193
+
194
+ self .actor_optimizer .zero_grad ()
195
+ actor_loss .backward ()
196
+ self .actor_optimizer .step ()
197
+
198
+ # Update Target Networks
199
+ for param , target_param in zip (self .critic .parameters (), self .critic_target .parameters ()):
200
+ target_param .data .copy_ (self .tau * param .data + (1 - self .tau ) * target_param .data )
201
+
202
+ for param , target_param in zip (self .actor .parameters (), self .actor_target .parameters ()):
203
+ target_param .data .copy_ (self .tau * param .data + (1 - self .tau ) * target_param .data )
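

if __name__ == "__main__":
    # Smoke-test sketch (illustrative, not part of the original file). It assumes
    # the replay-buffer interface used in train(): sample(batch_size) returns
    # (state, action, next_state, reward, cost, not_done) tensors.
    class RandomBuffer(object):
        def __init__(self, state_dim, action_dim, size=1000):
            self.state = torch.randn(size, state_dim)
            self.action = torch.rand(size, action_dim) * 2 - 1
            self.next_state = torch.randn(size, state_dim)
            self.reward = torch.randn(size, 1)
            self.cost = torch.rand(size, 1)
            self.not_done = torch.ones(size, 1)
            self.size = size

        def sample(self, batch_size):
            ind = torch.randint(0, self.size, (batch_size,))
            return (self.state[ind].to(device), self.action[ind].to(device),
                    self.next_state[ind].to(device), self.reward[ind].to(device),
                    self.cost[ind].to(device), self.not_done[ind].to(device))

    state_dim, action_dim = 3, 2
    policy = BCQ(state_dim, action_dim, max_action=1.0)
    buffer = RandomBuffer(state_dim, action_dim)
    for _ in range(5):
        policy.train(buffer, batch_size=32)
    print(policy.select_action(np.zeros(state_dim)))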