from __future__ import absolute_import, division, print_function

import os
import sys
import pickle
import random
from datetime import datetime

import numpy as np  # used throughout; previously only available via `from utils import *`
import torch
from tqdm import tqdm

from knowledge_graph import KnowledgeGraph
from utils import *


class KGState(object):
    def __init__(self, embed_size, history_len=1):
        self.embed_size = embed_size
        self.history_len = history_len  # number of previous hops kept in the state: one of {0, 1, 2}
        if history_len == 0:
            self.dim = 2 * embed_size
        elif history_len == 1:
            self.dim = 4 * embed_size
        elif history_len == 2:
            self.dim = 6 * embed_size
        else:
            raise Exception('history length should be one of {0, 1, 2}')

    def __call__(self, user_embed, node_embed, last_node_embed, last_relation_embed, older_node_embed,
                 older_relation_embed):
        if self.history_len == 0:
            return np.concatenate([user_embed, node_embed])
        elif self.history_len == 1:
            return np.concatenate([user_embed, node_embed, last_node_embed, last_relation_embed])
        elif self.history_len == 2:
            return np.concatenate([user_embed, node_embed, last_node_embed, last_relation_embed, older_node_embed,
                                   older_relation_embed])
        else:
            raise Exception('history length should be one of {0, 1, 2}')


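# Quick dimension check for KGState (illustrative numbers only, not tied to any
# dataset): with embed_size=100 and history_len=1 the state concatenates
# [user | curr_node | last_node | last_relation], i.e. 4 * 100 = 400 dims.
#
#   state_gen = KGState(100, history_len=1)
#   s = state_gen(u, v, v_prev, r_prev, np.zeros(100), np.zeros(100))
#   assert s.shape == (400,) and state_gen.dim == 400

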
class BatchKGEnvironment(object):
    def __init__(self, dataset_str, max_acts, max_path_len=3, state_history=1):
        self.max_acts = max_acts
        self.act_dim = max_acts + 1  # Add self-loop action, whose act_idx is always 0.
        self.max_num_nodes = max_path_len + 1  # maximum number of nodes in a path (#hops = #nodes - 1)
        self.kg = load_kg(dataset_str)
        self.embeds = load_embed(dataset_str)
        self.embed_size = self.embeds[USER].shape[1]
        self.embeds[SELF_LOOP] = (np.zeros(self.embed_size), 0.0)
        self.state_gen = KGState(self.embed_size, history_len=state_history)
        self.state_dim = self.state_gen.dim

        # Compute per-user maximum user-product scores for reward scaling.
        u_p_scores = np.dot(self.embeds[USER] + self.embeds[PURCHASE][0], self.embeds[PRODUCT].T)
        self.u_p_scales = np.max(u_p_scores, axis=1)

        # Compute valid path patterns (sequences of relations).
        self.patterns = []
        for pattern_id in [1, 11, 12, 13, 14, 15, 16, 17, 18]:
            pattern = PATH_PATTERN[pattern_id]
            pattern = [SELF_LOOP] + [v[0] for v in pattern[1:]]  # pattern contains all relations
            if pattern_id == 1:
                pattern.append(SELF_LOOP)
            self.patterns.append(tuple(pattern))

        # Following is current episode information.
        self._batch_path = None  # list of paths, each a list of tuples (relation, node_type, node_id)
        self._batch_curr_actions = None  # save current valid actions
        self._batch_curr_state = None
        self._batch_curr_reward = None
        # Only one 'done' indicator is needed, since all paths have the same length and finish at the same time.
        self._done = False

    def _has_pattern(self, path):
        pattern = tuple([v[0] for v in path])
        return pattern in self.patterns

    def _batch_has_pattern(self, batch_path):
        return [self._has_pattern(path) for path in batch_path]

    def _get_actions(self, path, done):
        """Compute valid actions for the current node."""
        _, curr_node_type, curr_node_id = path[-1]
        actions = [(SELF_LOOP, curr_node_id)]  # self-loop must be included.

        # (1) If the episode is finished, only return the self-loop action.
        if done:
            return actions

        # (2) Get all possible edges from the original knowledge graph.
        # [CAVEAT] Must remove visited nodes!
        relations_nodes = self.kg(curr_node_type, curr_node_id)
        candidate_acts = []  # list of tuples of (relation, node_id)
        visited_nodes = set([(v[1], v[2]) for v in path])
        for r in relations_nodes:
            next_node_type = KG_RELATION[curr_node_type][r]
            next_node_ids = relations_nodes[r]
            next_node_ids = [n for n in next_node_ids if (next_node_type, n) not in visited_nodes]  # filter
            candidate_acts.extend(zip([r] * len(next_node_ids), next_node_ids))

        # (3) If the candidate action set is empty, only return the self-loop action.
        if len(candidate_acts) == 0:
            return actions

        # (4) If the number of available actions does not exceed max_acts, return all of them.
        if len(candidate_acts) <= self.max_acts:
            candidate_acts = sorted(candidate_acts, key=lambda x: (x[0], x[1]))
            actions.extend(candidate_acts)
            return actions

        # (5) If there are too many actions, trim deterministically by embedding score.
        user_embed = self.embeds[USER][path[0][-1]]
        scores = []
        for r, next_node_id in candidate_acts:
            next_node_type = KG_RELATION[curr_node_type][r]
            if next_node_type == USER:
                src_embed = user_embed
            elif next_node_type == PRODUCT:
                src_embed = user_embed + self.embeds[PURCHASE][0]
            elif next_node_type == WORD:
                src_embed = user_embed + self.embeds[MENTION][0]
            else:  # BRAND, CATEGORY, RELATED_PRODUCT
                src_embed = user_embed + self.embeds[PURCHASE][0] + self.embeds[r][0]
            score = np.matmul(src_embed, self.embeds[next_node_type][next_node_id])
            # This trimming may filter out target products!
            # Manually set the score of target products to a very large number.
            # if next_node_type == PRODUCT and next_node_id in self._target_pids:
            #     score = 99999.0
            scores.append(score)
        candidate_idxs = np.argsort(scores)[-self.max_acts:]  # choose actions with larger scores
        candidate_acts = sorted([candidate_acts[i] for i in candidate_idxs], key=lambda x: (x[0], x[1]))
        actions.extend(candidate_acts)
        return actions

    def _batch_get_actions(self, batch_path, done):
        return [self._get_actions(path, done) for path in batch_path]

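    # Note on sizing (derivable from _get_actions above): each action list is
    # at most max_acts pruned edges plus the mandatory self-loop, so its length
    # never exceeds max_acts + 1 == act_dim. The fixed-size masks built in
    # batch_action_mask() below rely on this invariant.
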
    def _get_state(self, path):
        """Return state as a numpy vector: [user_embed, curr_node_embed, last_node_embed, last_relation_embed, ...]."""
        user_embed = self.embeds[USER][path[0][-1]]
        zero_embed = np.zeros(self.embed_size)
        if len(path) == 1:  # initial state
            state = self.state_gen(user_embed, user_embed, zero_embed, zero_embed, zero_embed, zero_embed)
            return state

        older_relation, last_node_type, last_node_id = path[-2]
        last_relation, curr_node_type, curr_node_id = path[-1]
        curr_node_embed = self.embeds[curr_node_type][curr_node_id]
        last_node_embed = self.embeds[last_node_type][last_node_id]
        last_relation_embed, _ = self.embeds[last_relation]  # this can be the self-loop!
        if len(path) == 2:
            state = self.state_gen(user_embed, curr_node_embed, last_node_embed, last_relation_embed, zero_embed,
                                   zero_embed)
            return state

        _, older_node_type, older_node_id = path[-3]
        older_node_embed = self.embeds[older_node_type][older_node_id]
        older_relation_embed, _ = self.embeds[older_relation]
        state = self.state_gen(user_embed, curr_node_embed, last_node_embed, last_relation_embed, older_node_embed,
                               older_relation_embed)
        return state

    def _batch_get_state(self, batch_path):
        batch_state = [self._get_state(path) for path in batch_path]
        return np.vstack(batch_state)  # [bs, dim]

    def _get_reward(self, path):
        # If it is the initial state or a 1-hop search, the reward is 0.
        if len(path) <= 2:
            return 0.0
        if not self._has_pattern(path):
            return 0.0

        target_score = 0.0
        _, curr_node_type, curr_node_id = path[-1]
        if curr_node_type == PRODUCT:
            # Give a soft reward for any reached product, scaled per user.
            uid = path[0][-1]
            u_vec = self.embeds[USER][uid] + self.embeds[PURCHASE][0]
            p_vec = self.embeds[PRODUCT][curr_node_id]
            score = np.dot(u_vec, p_vec) / self.u_p_scales[uid]
            target_score = max(score, 0.0)
        return target_score

    def _batch_get_reward(self, batch_path):
        batch_reward = [self._get_reward(path) for path in batch_path]
        return np.array(batch_reward)

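    # Note: u_p_scales[uid] is the maximum user-product score over all products
    # for that user, so (assuming that maximum is positive) the normalized
    # score is <= 1, and max(score, 0.0) clips negatives; the terminal reward
    # therefore lies in [0, 1].
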
    def _is_done(self):
        """Episode ends only when the max path length is reached."""
        return self._done or len(self._batch_path[0]) >= self.max_num_nodes

    def reset(self, uids=None):
        if uids is None:
            all_uids = list(self.kg(USER).keys())
            uids = [random.choice(all_uids)]

        # Each path element is a tuple of (relation, entity_type, entity_id).
        self._batch_path = [[(SELF_LOOP, USER, uid)] for uid in uids]
        self._done = False
        self._batch_curr_state = self._batch_get_state(self._batch_path)
        self._batch_curr_actions = self._batch_get_actions(self._batch_path, self._done)
        self._batch_curr_reward = self._batch_get_reward(self._batch_path)
        return self._batch_curr_state

    def batch_step(self, batch_act_idx):
        """
        Args:
            batch_act_idx: list of integers.
        Returns:
            batch_next_state: numpy array of size [bs, state_dim].
            batch_reward: numpy array of size [bs].
            done: True/False
        """
        assert len(batch_act_idx) == len(self._batch_path)

        # Execute batch actions.
        for i in range(len(batch_act_idx)):
            act_idx = batch_act_idx[i]
            _, curr_node_type, curr_node_id = self._batch_path[i][-1]
            relation, next_node_id = self._batch_curr_actions[i][act_idx]
            if relation == SELF_LOOP:
                next_node_type = curr_node_type
            else:
                next_node_type = KG_RELATION[curr_node_type][relation]
            self._batch_path[i].append((relation, next_node_type, next_node_id))

        self._done = self._is_done()  # must run before getting actions, etc.
        self._batch_curr_state = self._batch_get_state(self._batch_path)
        self._batch_curr_actions = self._batch_get_actions(self._batch_path, self._done)
        self._batch_curr_reward = self._batch_get_reward(self._batch_path)
        return self._batch_curr_state, self._batch_curr_reward, self._done

    def batch_action_mask(self, dropout=0.0):
        """Return action masks of size [bs, act_dim]."""
        batch_mask = []
        for actions in self._batch_curr_actions:
            act_idxs = list(range(len(actions)))
            if dropout > 0 and len(act_idxs) >= 5:
                keep_size = int(len(act_idxs[1:]) * (1.0 - dropout))
                tmp = np.random.choice(act_idxs[1:], keep_size, replace=False).tolist()
                act_idxs = [act_idxs[0]] + tmp  # always keep the self-loop (index 0)
            act_mask = np.zeros(self.act_dim, dtype=np.uint8)
            act_mask[act_idxs] = 1
            batch_mask.append(act_mask)
        return np.vstack(batch_mask)

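    # Example of the mask semantics (hypothetical sizes): with act_dim=5 and 3
    # valid actions, batch_action_mask(dropout=0.0) yields rows like
    # [1, 1, 1, 0, 0]; with dropout > 0, some non-self-loop entries are zeroed
    # at random, but index 0 (the self-loop) is always kept.
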
    def print_path(self):
        for path in self._batch_path:
            msg = 'Path: {}({})'.format(path[0][1], path[0][2])
            for node in path[1:]:
                msg += ' =={}=> {}({})'.format(node[0], node[1], node[2])
            print(msg)
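

# Minimal usage sketch (hypothetical: assumes a preprocessed dataset named
# 'beauty' that load_kg/load_embed from utils can locate, and peeks at the
# private _batch_curr_actions attribute purely for illustration; the rollout
# policy here just picks random valid actions):
if __name__ == '__main__':
    env = BatchKGEnvironment('beauty', max_acts=250, max_path_len=3, state_history=1)
    state = env.reset(uids=[0])  # batch of one path, starting at user 0
    done = False
    while not done:
        # Sample one valid action index per path in the batch.
        batch_act_idx = [random.randrange(len(acts)) for acts in env._batch_curr_actions]
        state, reward, done = env.batch_step(batch_act_idx)
    env.print_path()
    print('Final reward:', reward)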