-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenvelope_q_learning.py
414 lines (378 loc) · 19.2 KB
/
envelope_q_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as gd
from torch import Tensor
from torch_scatter import scatter
from gflownet.config import Config
from gflownet.envs.graph_building_env import (
GraphActionCategorical,
GraphBuildingEnv,
GraphBuildingEnvContext,
generate_forward_trajectory,
)
from gflownet.models.graph_transformer import GraphTransformer, mlp
from gflownet.trainer import GFNTask
from .graph_sampling import GraphSampler
# Custom models are necessary for envelope Q Learning
class GraphTransformerFragEnvelopeQL(nn.Module):
    """Graph-transformer Q-network for an EnvelopeQLearning agent on fragment graphs.

    A vector of ``num_objectives`` Q-values is produced for every action of the
    following types:
    - Stop
    - AddNode
    - SetEdgeAttr
    """

    def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objectives=2):
        """Build the transformer trunk and the per-action-type heads.

        Parameters
        ----------
        env_ctx:
            Environment context; this constructor reads its ``num_node_dim``,
            ``num_edge_dim``, ``num_cond_dim``, ``num_new_node_values``,
            ``num_edge_attr_logits`` and ``action_type_order`` attributes.
        num_emb, num_layers, num_heads:
            Forwarded to `GraphTransformer`; see that class for their meaning.
        num_objectives: int
            Number of Q-values emitted per action.
        """
        super().__init__()
        self.transf = GraphTransformer(
            x_dim=env_ctx.num_node_dim,
            e_dim=env_ctx.num_edge_dim,
            g_dim=env_ctx.num_cond_dim,
            num_emb=num_emb,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        node_width = num_emb * 2
        graph_width = num_emb * 3
        hidden_layers = 0

        # NOTE: heads are created in a fixed order; keep it, since it determines
        # the order in which parameters are registered and initialised.
        self.emb2add_node = mlp(
            node_width, num_emb, env_ctx.num_new_node_values * num_objectives, hidden_layers
        )
        # Edge attr logits are "sided": the same head is evaluated once per
        # endpoint, hence only half of the logits are produced per call.
        self.emb2set_edge_attr = mlp(
            num_emb + node_width,
            num_emb,
            env_ctx.num_edge_attr_logits // 2 * num_objectives,
            hidden_layers,
        )
        self.emb2stop = mlp(graph_width, num_emb, num_objectives, hidden_layers)
        self.emb2reward = mlp(graph_width, num_emb, 1, hidden_layers)
        self.edge2emb = mlp(node_width, num_emb, num_emb, hidden_layers)
        self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)

        self.action_type_order = env_ctx.action_type_order
        # Value written into masked-out entries of the Q tensors.
        self.mask_value = -10
        self.num_objectives = num_objectives

    def _mask_flat(self, values, mask):
        """Keep `values` where `mask` is 1 and write `mask_value` where it is 0."""
        return values * mask + self.mask_value * (1 - mask)

    def _mask_per_objective(self, values, mask):
        """Like `_mask_flat`, but `mask` holds one entry per action.

        `values` is viewed as (rows, actions, objectives) so that the mask is
        broadcast over the objectives axis, then restored to its input shape.
        """
        expanded = mask[:, :, None]
        per_objective = values.reshape(
            values.shape[0], values.shape[1] // self.num_objectives, self.num_objectives
        )
        masked = per_objective * expanded + self.mask_value * (1 - expanded)
        return masked.reshape(values.shape)

    @staticmethod
    def _greedy_logits(q_values, batch_indices, weights):
        """Scalarise vector Qs as w . Q and sharpen them into near-argmax logits."""
        num_w = weights.shape[1]
        sharpened = []
        for q, graph_of_row in zip(q_values, batch_indices):
            per_objective = q.reshape((q.shape[0], q.shape[1] // num_w, num_w))
            scalarised = (per_objective * weights[graph_of_row][:, None, :]).sum(2)
            # Set the softmax distribution to a very low temperature to make sure
            # only the max gets sampled (with random argmax tie breaking for free).
            sharpened.append(scalarised * 100)
        return sharpened

    def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
        """Compute Q-values (or a greedy policy) for a batch of graphs.

        See `GraphTransformer` for the meaning of `g` and `cond`.

        Returns
        -------
        cat: GraphActionCategorical
            If `output_Qs` is true its logits are the raw (masked) vector Qs.
            Otherwise its logits are the preference-weighted Qs scaled by 100
            and its `masks` attribute is populated.
        r_pred: Tensor
            Output of the `emb2reward` head on the graph embeddings.
        """
        node_emb, graph_emb = self.transf(g, cond)
        # On `::2`, edges are duplicated to make graphs undirected, only take the even ones
        src, dst = g.edge_index[:, ::2]
        edge_emb = self.edge2emb(node_emb[src] + node_emb[dst])
        src_side = self.emb2set_edge_attr(torch.cat([edge_emb, node_emb[src]], 1))
        dst_side = self.emb2set_edge_attr(torch.cat([edge_emb, node_emb[dst]], 1))

        stop_q = F.relu(self.emb2stop(graph_emb))
        add_node_q = self._mask_flat(F.relu(self.emb2add_node(node_emb)), g.add_node_mask)
        set_edge_attr_q = self._mask_per_objective(
            F.relu(torch.cat([src_side, dst_side], 1)), g.set_edge_attr_mask
        )
        cat = GraphActionCategorical(
            g,
            logits=[stop_q, add_node_q, set_edge_attr_q],
            keys=[None, "x", "edge_index"],
            types=self.action_type_order,
        )
        r_pred = self.emb2reward(graph_emb)
        if output_Qs:
            return cat, r_pred

        cat.masks = [1, g.add_node_mask.cpu(), g.set_edge_attr_mask.cpu()]
        # Compute the greedy policy
        # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
        # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
        preferences = cond[:, -self.num_objectives :]
        cat.logits = self._greedy_logits(cat.logits, cat.batch, preferences)
        return cat, r_pred
class GraphTransformerEnvelopeQL(nn.Module):
    """Graph-transformer Q-network for an EnvelopeQLearning agent.

    A vector of ``num_objectives`` Q-values is produced per action for five
    heads: stop, add node, set node attribute, add edge and set edge attribute.
    """

    def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objectives=2):
        """Build the transformer trunk and the Q heads.

        See `GraphTransformer` for the meaning of `num_emb`, `num_layers` and
        `num_heads`. `env_ctx` supplies the input dimensions, the number of
        logits of each action type and the action type order.
        """
        super().__init__()
        self.transf = GraphTransformer(
            x_dim=env_ctx.num_node_dim,
            e_dim=env_ctx.num_edge_dim,
            g_dim=env_ctx.num_cond_dim,
            num_emb=num_emb,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        node_width = num_emb * 2
        graph_width = num_emb * 3
        hidden_layers = 0

        # NOTE: heads are created in a fixed order; keep it, since it determines
        # the order in which parameters are registered and initialised.
        self.emb2add_edge = mlp(node_width, num_emb, num_objectives, hidden_layers)
        self.emb2add_node = mlp(
            node_width, num_emb, env_ctx.num_new_node_values * num_objectives, hidden_layers
        )
        self.emb2set_node_attr = mlp(
            node_width, num_emb, env_ctx.num_node_attr_logits * num_objectives, hidden_layers
        )
        self.emb2set_edge_attr = mlp(
            node_width, num_emb, env_ctx.num_edge_attr_logits * num_objectives, hidden_layers
        )
        self.emb2stop = mlp(graph_width, num_emb, num_objectives, hidden_layers)
        self.emb2reward = mlp(graph_width, num_emb, 1, hidden_layers)
        self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)

        self.action_type_order = env_ctx.action_type_order
        self.num_objectives = num_objectives

    def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False):
        """Compute Q-values (or a greedy policy) for a batch of graphs.

        See `GraphTransformer` for the meaning of `g` and `cond`.

        Returns
        -------
        cat: GraphActionCategorical
            If `output_Qs` is true its logits are the raw vector Qs, otherwise
            they are the preference-weighted Qs scaled by 100.
        r_pred: Tensor
            Output of the `emb2reward` head on the graph embeddings.
        """
        node_emb, graph_emb = self.transf(g, cond)
        non_edge_src, non_edge_dst = g.non_edge_index
        # On `::2`, edges are duplicated to make graphs undirected, only take the even ones
        edge_src, edge_dst = g.edge_index[:, ::2]

        stop_q = self.emb2stop(graph_emb)
        add_node_q = self.emb2add_node(node_emb)
        set_node_attr_q = self.emb2set_node_attr(node_emb)
        add_edge_q = self.emb2add_edge(node_emb[non_edge_src] + node_emb[non_edge_dst])
        set_edge_attr_q = self.emb2set_edge_attr(node_emb[edge_src] + node_emb[edge_dst])

        cat = GraphActionCategorical(
            g,
            logits=[stop_q, add_node_q, set_node_attr_q, add_edge_q, set_edge_attr_q],
            keys=[None, "x", "x", "non_edge_index", "edge_index"],
            types=self.action_type_order,
        )
        r_pred = self.emb2reward(graph_emb)
        if output_Qs:
            return cat, r_pred

        # Compute the greedy policy
        # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations
        # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes
        preferences = cond[:, -self.num_objectives :]
        num_w = preferences.shape[1]
        greedy_logits = []
        for q, graph_of_row in zip(cat.logits, cat.batch):
            # View the flat head output as (rows, actions, objectives) and take
            # the dot product with the preference vector of each row's graph.
            per_objective = q.reshape((q.shape[0], q.shape[1] // num_w, num_w))
            scalarised = (per_objective * preferences[graph_of_row][:, None, :]).sum(2)
            # Set the softmax distribution to a very low temperature to make sure only the
            # max gets sampled (and we get random argmax tie breaking for free!):
            greedy_logits.append(scalarised * 100)
        cat.logits = greedy_logits
        return cat, r_pred
class EnvelopeQLearning:
    """Envelope Q-Learning for multi-objective graph generation.

    See
    A Generalized Algorithm for Multi-Objective Reinforcement Learning and Policy Adaptation,
    Runzhe Yang, Xingyuan Sun, Karthik Narasimhan,
    NeurIPS 2019,
    https://arxiv.org/abs/1908.08342
    """

    def __init__(
        self,
        env: GraphBuildingEnv,
        ctx: GraphBuildingEnvContext,
        task: GFNTask,
        rng: np.random.RandomState,
        cfg: Config,
    ):
        """Set up the algorithm from the experiment configuration.

        Hyperparameters read from `cfg`:
            algo.max_len, algo.max_nodes: forwarded to the `GraphSampler`
            algo.illegal_action_logreward: float, log(R) given to the model for non-sane end
                states or illegal actions
            algo.moql.gamma: discount factor, must be 1
            algo.moql.num_objectives: number of objectives (entries of each Q vector)
            algo.moql.num_omega_samples: number of preference vectors w' sampled per state
            algo.moql.lambda_decay: controls the schedule mixing the two loss terms
            algo.moql.penalty: value added to the reward of invalid trajectories

        Parameters
        ----------
        env: GraphBuildingEnv
            A graph environment.
        ctx: GraphBuildingEnvContext
            A context.
        task: GFNTask
            The task; used to sample the conditional information of the w' batch.
        rng: np.random.RandomState
            rng used to take random actions
        cfg: Config
            The experiment configuration
        """
        self.ctx = ctx
        self.env = env
        self.task = task
        self.rng = rng
        self.max_len = cfg.algo.max_len
        self.max_nodes = cfg.algo.max_nodes
        self.illegal_action_logreward = cfg.algo.illegal_action_logreward
        self.gamma = cfg.algo.moql.gamma
        self.num_objectives = cfg.algo.moql.num_objectives
        self.num_omega_samples = cfg.algo.moql.num_omega_samples
        self.lambda_decay = cfg.algo.moql.lambda_decay
        self.invalid_penalty = cfg.algo.moql.penalty
        self._num_updates = 0
        assert self.gamma == 1, "EnvelopeQLearning requires cfg.algo.moql.gamma == 1"
        self.bootstrap_own_reward = False
        # Experimental flags
        self.sample_temp = 1
        self.do_q_prime_correction = False
        self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp)

    def create_training_data_from_own_samples(
        self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float
    ):
        """Generate trajectories by sampling a model

        Parameters
        ----------
        model: nn.Module
            The model being sampled
        n: int
            Forwarded to `GraphSampler.sample_from_model`; presumably the number of
            trajectories to sample (verify against `GraphSampler`).
        cond_info: torch.tensor
            Conditional information, moved to `self.ctx.device` before sampling.
        random_action_prob: float
            Probability of taking a random action

        Returns
        -------
        data: List[Dict]
            Whatever `GraphSampler.sample_from_model` returns. `construct_batch` expects
            each entry to hold its (graph, action) pairs under the key "traj" and
            optionally a validity flag under "is_valid".
        """
        dev = self.ctx.device
        cond_info = cond_info.to(dev)
        return self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob)

    def create_training_data_from_graphs(self, graphs):
        """Generate trajectories from known endpoints

        Parameters
        ----------
        graphs: List[Graph]
            List of Graph endpoints

        Returns
        -------
        trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}]
            A list of trajectories.
        """
        return [{"traj": generate_forward_trajectory(i)} for i in graphs]

    def construct_batch(self, trajs, cond_info, log_rewards):
        """Construct a batch from a list of trajectories and their information

        Parameters
        ----------
        trajs: List[Dict]
            A list of N trajectories. Each is a dict whose "traj" entry is a list of
            (Graph, GraphAction) pairs; an optional "is_valid" entry defaults to True.
        cond_info: Tensor
            The conditional info that is considered for each trajectory. Shape (N, n_info)
        log_rewards: Tensor
            The transformed log-reward (e.g. torch.log(R(x) ** beta) ) for each trajectory. Shape (N,)

        Returns
        -------
        batch: gd.Batch
            A Batch object with `traj_lens`, `actions`, `log_rewards`, `cond_info` and
            `is_valid` attached, plus `batch_prime`: a second Batch in which every state
            is repeated `num_omega_samples` times, each copy with freshly sampled
            conditional information (`cond_info`) and preferences.
        """
        torch_graphs = [self.ctx.graph_to_Data(i[0]) for tj in trajs for i in tj["traj"]]
        actions = [
            self.ctx.GraphAction_to_aidx(g, a) for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
        ]
        batch = self.ctx.collate(torch_graphs)
        batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
        batch.actions = torch.tensor(actions)
        batch.log_rewards = log_rewards
        batch.cond_info = cond_info
        batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float()

        # Now we create a duplicate/repeated batch for Q(s,a,w').
        # One w' is sampled per *repeated state*: row k of omega_prime belongs to
        # graph k of batch_prime (state k // num_omega_samples of `batch`).
        omega_prime = self.task.sample_conditional_information(self.num_omega_samples * batch.num_graphs)
        torch_graphs = [i for i in torch_graphs for j in range(self.num_omega_samples)]
        actions = [i for i in actions for j in range(self.num_omega_samples)]
        batch_prime = self.ctx.collate(torch_graphs)
        batch_prime.traj_lens = batch.traj_lens.repeat_interleave(self.num_omega_samples)
        batch_prime.actions = torch.tensor(actions)
        batch_prime.cond_info = omega_prime["encoding"]
        batch_prime.preferences = omega_prime["preferences"]
        batch.batch_prime = batch_prime
        return batch

    def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: int = 0):
        """Compute the losses over trajectories contained in the batch

        Parameters
        ----------
        model: nn.Module
            A GNN taking in a batch of graphs as input as per constructed by
            `self.construct_batch`. It is called as `model(batch, cond, output_Qs=True)` and
            must return a GraphActionCategorical whose logits are vector-valued Qs.
        batch: gd.Batch
            batch of graphs inputs as per constructed by `self.construct_batch`. In addition,
            the caller must have attached `preferences`, `flat_rewards`, `num_offline` and
            `num_online`, which `construct_batch` does not set.
        num_bootstrap: int
            Currently unused by this implementation.

        Returns
        -------
        loss: Tensor
            Mean over states of the mixed loss (1 - Lambda) * L_A + Lambda * L_B.
        info: Dict
            Logging quantities.

        Raises
        ------
        ValueError
            If any per-trajectory loss is not finite.
        """
        dev = batch.x.device
        # A single trajectory is comprised of many graphs
        num_trajs = int(batch.traj_lens.shape[0])
        cond_info = batch.cond_info
        num_objectives = self.num_objectives

        # This index says which trajectory each graph belongs to, so
        # it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is
        # of length 4, trajectory 1 of length 3, and so on.
        batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens)
        num_states = batch_idx.shape[0]
        # The position of the last graph of each trajectory
        final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1

        # Forward pass of the model, returns a GraphActionCategorical and per molecule predictions
        # (the latter are not needed here).
        # Here we will interpret the logits of the fwd_cat as Q values
        # Q(s,a,omega)
        fwd_cat, _ = model(batch, cond_info[batch_idx], output_Qs=True)
        Q_omega = fwd_cat.logits
        # reshape to List[shape: (num <T> in all graphs, num actions on T, num_objectives) | for all types T]
        Q_omega = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega]

        # We do the same for omega', Q(s, a, w')
        batchp = batch.batch_prime
        # `construct_batch` sampled exactly one w' per graph of batch_prime (i.e. per repeated
        # state), so the conditioning is already aligned with the graphs and is passed as-is.
        # Indexing it with a per-trajectory index instead would hand (mostly) the same w' to
        # all the copies of a state and make the max over w' below degenerate.
        fwd_cat_prime, _ = model(batchp, batchp.cond_info, output_Qs=True)
        Q_omega_prime = fwd_cat_prime.logits
        # We've repeated everything N_omega times, so the rows of each tensor now number
        # N_omega * num objects; the reshape is otherwise the same as above.
        Q_omega_prime = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega_prime]

        # The math is
        #     y = r + \arg_Q \max_{a,w'} w . Q(s', a, w')
        # so we're going to compute all the dot products w . Q(s, a, w'),
        # then take the argmax over actions and over w'.
        w = batch.preferences[batch_idx]  # shape: (num_graphs, num_objectives)
        w_dot_Q = [
            # Broadcast preferences w over the actions axis, then sum.
            # What's going on with the indexing here?
            # here qi has shape (N_omega * num objects, num actions, num objectives)
            # w has shape (num graphs, num objectives), and we index it by b
            # b has shape (num objects), and refers to which graph each object entry corresponds to in fwd_cat.
            # For use in Q_omega_prime, we repeat every state N_omega times,
            # therefore, we repeat_interleave b so that each repeated state has its own w copy
            (
                qi
                * (
                    w[b.repeat_interleave(self.num_omega_samples)].reshape(
                        (self.num_omega_samples * b.shape[0], 1, num_objectives)
                    )
                )
                # then we multiply the Q(s, a, w') and w, take the sum on the right axis for the dot
            ).sum(2)
            for qi, b in zip(Q_omega_prime, fwd_cat.batch)
        ]  # List[shape: (N_omega * num objects, num actions)]

        # Now we need to do an argmax, over actions _and_ omegas, of the dot which has shape
        # List[(N_omega * num objects, num actions).] To do this fortunately we can reuse the
        # mechanisms of GraphActionCategorical. Once again we repeat interleave the batch indices
        # (of fwd_cat) to map back to the original states -- this makes it so that what are
        # considered different states (because they are repeated) in batch_prime are considered the
        # same (and thus the max is over all of the repeats as well).
        # Since the batch slices we will later index to get Q[:, argmax a, argmax omega'] are those
        # of Q_omega_prime, we need to use fwd_cat_prime.
        argmax = fwd_cat_prime.argmax(
            x=w_dot_Q, batch=[b.repeat_interleave(self.num_omega_samples) for b in fwd_cat.batch], dim_size=num_states
        )
        # Now what we want, for each state, is the vector prediction made by Q(s, a, w') for the
        # argmax a,w'. Let's again reuse GraphActionCategorical methods to do the indexing for us.
        # We must again use fwd_cat_prime to use the right slices.
        Q_pareto = fwd_cat_prime.log_prob(actions=argmax, logprobs=Q_omega_prime)
        # shape: (num_graphs, num_objectives)

        # Now we have \arg_Q \max_{a,w'} w . Q(s, a, w') for each state, we really want Q(s', ...)
        # Shift t+1-> t, pad last state with a 0, multiply by gamma
        shifted_Q_pareto = self.gamma * torch.cat([Q_pareto[1:], torch.zeros_like(Q_pareto[:1])], dim=0)
        # Replace Q(s_T) with R(tau). Since we've shifted the values in the array, Q(s_T) is Q(s_0)
        # of the next trajectory in the array, and rewards are terminal (0 except at s_T).
        shifted_Q_pareto[final_graph_idx] = batch.flat_rewards + ((1 - batch.is_valid) * self.invalid_penalty)[:, None]
        # The target is treated as a constant: no gradient flows through it.
        y = shifted_Q_pareto.detach()

        # We now use the same log_prob trick to get Q(s,a,w)
        Q_saw = fwd_cat.log_prob(actions=batch.actions, logprobs=Q_omega)
        # and compute L_A, the squared error between target and prediction vectors
        loss_A = (y - Q_saw).pow(2).sum(1)
        # and L_B, the absolute error between their preference-weighted scalarisations
        loss_B = ((w * y).sum(1) - (w * Q_saw).sum(1)).abs()

        # Lambda is 0 on the first update (assuming lambda_decay > 0) and grows with the
        # number of updates, shifting the weight from L_A to L_B.
        Lambda = 1 - self.lambda_decay / (self.lambda_decay + self._num_updates)
        losses = (1 - Lambda) * loss_A + Lambda * loss_B
        self._num_updates += 1

        traj_losses = scatter(losses, batch_idx, dim=0, dim_size=num_trajs, reduce="sum")
        loss = losses.mean()
        invalid_mask = 1 - batch.is_valid
        info = {
            "loss": loss.item(),
            "loss_A": loss_A.mean(),
            "loss_B": loss_B.mean(),
            "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0,
            "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0,
            "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0,
            # 1e-4 guards against division by zero when every trajectory is valid
            "invalid_losses": (invalid_mask * traj_losses).sum() / (invalid_mask.sum() + 1e-4),
        }

        if not torch.isfinite(traj_losses).all():
            raise ValueError("loss is not finite")
        return loss, info