forked from fimrie/DEVELOP
-
Notifications
You must be signed in to change notification settings - Fork 13
/
data_augmentation.py
executable file
·310 lines (296 loc) · 16.2 KB
/
data_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
from utils import *
from copy import deepcopy
import random
# Generate the mask based on the valences and the adjacency matrix built so far.
# For a (node_in_focus, neighbor, edge_type) triple to be valid, the neighbor's color must be < 2,
# there must be no edge so far between node_in_focus and neighbor, the valence constraint
# must be satisfied, and node_in_focus != neighbor.
def generate_mask(valences, adj_mat, color, real_n_vertices, node_in_focus, check_overlap_edge, new_mol):
    """Collect the edges that may still be attached to ``node_in_focus``.

    A vertex ``v`` in ``range(real_n_vertices)`` is a candidate when it is not
    the node in focus, ``color[v] < 2`` and ``check_adjacent_sparse`` reports no
    existing edge between the two nodes in ``adj_mat``.

    Args:
        valences: indexable of remaining valences per node.
        adj_mat: adjacency structure handed to ``check_adjacent_sparse``.
        color: indexable of per-node colors (only ``< 2`` and ``== 1`` are tested).
        real_n_vertices: number of vertices to scan.
        node_in_focus: index of the node edges would be attached to.
        check_overlap_edge: when truthy, run the ring-overlap trial below.
        new_mol: molecule object exposing ``AddBond`` / ``RemoveBond``; it is
            modified temporarily and restored before returning.

    Returns:
        ``(edge_type_mask, edge_mask)`` where ``edge_type_mask`` holds
        ``(node_in_focus, v, t)`` for every ``t`` in ``range(min_valence)`` and
        ``edge_mask`` holds ``(node_in_focus, v)`` whenever ``min_valence > 0``.
    """
    typed_candidates = []
    candidates = []
    for other in range(real_n_vertices):
        # Guard clauses: self, finished nodes and already-connected nodes are skipped.
        if other == node_in_focus or not (color[other] < 2):
            continue
        if check_adjacent_sparse(adj_mat, node_in_focus, other)[0]:
            continue

        # At most three edge types exist, so the budget is capped at 3.
        budget = min(valences[node_in_focus], valences[other], 3)
        has_room = budget > 0

        if check_overlap_edge and has_room and color[other] == 1:
            # Tentatively add the bond, inspect the rings, then undo the bond.
            new_mol.AddBond(int(node_in_focus), int(other), number_to_bond[0])
            rings = Chem.GetSymmSSSR(new_mol)
            ring_members = [set(rings[k]) for k in range(len(rings))]
            # Reject the edge if any two rings share more than two members.
            rejected = any(
                len(ring_members[a] & ring_members[b]) > 2
                for a in range(len(ring_members))
                for b in range(a + 1, len(ring_members))
            )
            new_mol.RemoveBond(int(node_in_focus), int(other))
            if rejected:
                continue

        typed_candidates.extend(
            (node_in_focus, other, bond_type) for bond_type in range(budget)
        )
        # There might be an edge between the node in focus and this vertex.
        if has_room:
            candidates.append((node_in_focus, other))
    return typed_candidates, candidates
# when a new edge is about to be added, we generate labels based on ground truth
# if an edge is in ground truth and has not been added to incremental adj yet, we label it as positive
def generate_label(ground_truth_graph, incremental_adj, node_in_focus, real_neighbor, real_n_vertices, params):
    """Build the positive edge labels for ``node_in_focus`` at the current step.

    An edge to vertex ``v`` is labelled positive when ``check_adjacent_sparse``
    finds it in ``ground_truth_graph`` but not in ``incremental_adj``. When
    ``params["label_one_hot"]`` is truthy, only ``v == real_neighbor`` is kept.

    Args:
        ground_truth_graph: adjacency structure of the target graph.
        incremental_adj: adjacency structure of the graph built so far.
        node_in_focus: index of the node whose edges are labelled.
        real_neighbor: the neighbor actually chosen at this step.
        real_n_vertices: number of vertices to scan.
        params: mapping providing ``"label_one_hot"``.

    Returns:
        ``(edge_type_label, edge_label)``: a list of
        ``(node_in_focus, v, edge_type)`` triples and the matching list of
        ``(node_in_focus, v)`` pairs.
    """
    typed_labels = []
    plain_labels = []
    for candidate in range(real_n_vertices):
        in_truth, bond_type = check_adjacent_sparse(ground_truth_graph, node_in_focus, candidate)
        already_added, _ = check_adjacent_sparse(incremental_adj, node_in_focus, candidate)
        one_hot = params["label_one_hot"]

        # Only ground-truth edges that are still missing can be positive.
        if not in_truth or already_added:
            continue
        if one_hot:
            # One-hot labelling keeps only the edge that is really added next.
            if candidate != real_neighbor:
                continue
        else:
            assert bond_type < 3

        typed_labels.append((node_in_focus, candidate, bond_type))
        plain_labels.append((node_in_focus, candidate))
    return typed_labels, plain_labels
# add an incremental adj with one new edge
def genereate_incremental_adj(last_adj, node_in_focus, neighbor, edge_type):
    """Return a deep copy of ``last_adj`` extended by one undirected edge.

    ``last_adj`` itself is left untouched. The copy receives
    ``(neighbor, edge_type)`` under ``node_in_focus`` and
    ``(node_in_focus, edge_type)`` under ``neighbor``, in that order.
    """
    extended = deepcopy(last_adj)
    # Record the edge once per direction.
    for endpoint, partner in ((node_in_focus, neighbor), (neighbor, node_in_focus)):
        extended[endpoint].append((partner, edge_type))
    return extended
def update_one_step(overlapped_edge_features, distance_to_others, node_sequence, node_in_focus, neighbor, edge_type,
                    edge_type_masks, valences, incremental_adj_mat, color, real_n_vertices, graph,
                    edge_type_labels, local_stop, edge_masks, edge_labels, local_stop_label, params,
                    check_overlap_edge, new_mol, up_to_date_adj_mat, keep_prob):
    """Append one generation step to all of the per-step record lists.

    When ``params["sample_transition"]`` is truthy, the step is dropped (nothing
    is appended and ``None`` is returned) whenever ``random.random() > keep_prob``.

    Otherwise one entry is appended, in place, to each of ``node_sequence``,
    ``edge_type_masks``, ``edge_masks``, ``edge_type_labels``, ``edge_labels``,
    ``local_stop``, ``distance_to_others``, ``overlapped_edge_features`` and
    ``incremental_adj_mat``. ``edge_type`` is accepted for call-site symmetry
    but is not read here.
    """
    # Sub-sample transitions: skip this one entirely if it is not kept.
    if params["sample_transition"] and random.random() > keep_prob:
        return

    # The node currently being expanded.
    node_sequence.append(node_in_focus)

    # Masks describing which edges are allowed in the current state.
    step_type_mask, step_mask = generate_mask(
        valences, up_to_date_adj_mat, color, real_n_vertices,
        node_in_focus, check_overlap_edge, new_mol)
    edge_type_masks.append(step_type_mask)
    edge_masks.append(step_mask)

    if local_stop_label:
        # A stop step carries no positive edges.
        edge_type_labels.append([])
        edge_labels.append([])
    else:
        # Labels come from the ground-truth graph versus the edges added so far.
        step_type_label, step_label = generate_label(
            graph, up_to_date_adj_mat, node_in_focus, neighbor, real_n_vertices, params)
        edge_type_labels.append(step_type_label)
        edge_labels.append(step_label)

    local_stop.append(local_stop_label)

    # BFS distances from the node in focus, truncated at params["truncate_distance"].
    truncated = []
    for start, node, d in bfs_distance(node_in_focus, up_to_date_adj_mat):
        if d > params["truncate_distance"]:
            truncated.append((start, node, params["truncate_distance"]))
        else:
            truncated.append((start, node, d))
    distance_to_others.append(truncated)

    # Overlapped-edge features for the candidate edges of this step.
    overlapped_edge_features.append(get_overlapped_edge_feature(step_mask, color, new_mol))

    # Snapshot the adjacency so later mutations cannot alter this step's record.
    incremental_adj_mat.append(deepcopy(up_to_date_adj_mat))
def construct_incremental_graph(dataset, edges, max_n_vertices, real_n_vertices, node_symbol, params, is_training_data, initial_idx=0):
    """Replay a breadth-first construction of a molecule and record every step.

    Starting from ``initial_idx`` (reset to 0 when it is not below
    ``real_n_vertices``), the graph described by ``edges`` is traversed
    breadth-first. Before each edge is added, and once more when a node is
    finished (the "local stop"), ``update_one_step`` records the current state.

    Args:
        dataset: forwarded to ``get_initial_valence``, ``sample_node_symbol``
            and ``add_atoms``.
        edges: list of ``(src, edge_type, dst)`` triples; the reverse direction
            of every edge is added here.
        max_n_vertices: length of the color array.
        real_n_vertices: number of vertices considered when masking/labelling.
        node_symbol: per-node vectors; the argmax of each is passed to
            ``get_initial_valence``.
        params: mapping providing ``"generation"``, ``"bfs_path_count"``,
            ``"path_random_order"``, ``"check_overlap_edge"`` and the keys read
            by ``update_one_step``.
        is_training_data: together with ``params["generation"]`` triggers the
            early return of nine empty lists.
        initial_idx: index of the node the traversal starts from.

    Returns:
        A 9-tuple of lists: ``(incremental_adj_mat, distance_to_others,
        node_sequence, edge_type_masks, edge_type_labels, local_stop,
        edge_masks, edge_labels, overlapped_edge_features)``.
    """
    # Nothing to precompute when we only generate new molecules.
    if params["generation"] and is_training_data:
        return tuple([] for _ in range(9))

    # Fall back to node 0 if the requested start node does not exist.
    if initial_idx >= real_n_vertices:
        initial_idx = 0

    # Maximum valences for each node.
    valences = get_initial_valence([np.argmax(symbol) for symbol in node_symbol], dataset)

    # Make the edge list symmetric: forward edges first, then their reverses.
    edges = edges + [(dst, edge_type, src) for src, edge_type, dst in edges]
    graph = defaultdict(list)
    for src, edge_type, dst in edges:
        graph[src].append((dst, edge_type))

    # BFS bookkeeping. Color 0: not found, 1: in the queue, 2: searched already.
    color = [0] * max_n_vertices
    color[initial_idx] = 1
    queue = deque([initial_idx])

    # Adjacency of the partially built molecule (no edges yet).
    up_to_date_adj_mat = defaultdict(list)

    # One entry per recorded step in each of these lists.
    incremental_adj_mat = []
    distance_to_others = []
    overlapped_edge_features = []
    node_sequence = []
    edge_type_masks = []
    edge_type_labels = []
    edge_masks = []
    edge_labels = []
    local_stop = []

    # The incremental molecule: all atoms up front, bonds added during the BFS.
    new_mol = Chem.rdchem.RWMol(Chem.MolFromSmiles(''))
    add_atoms(new_mol, sample_node_symbol([node_symbol], [len(node_symbol)], dataset)[0], dataset)

    # Keep probability for transition sampling (to form a binomial distribution).
    n_transitions = real_n_vertices + len(edges) / 2
    keep_prob = float(n_transitions) / (n_transitions * params["bfs_path_count"])

    def record_step(focus, target, bond_type, is_stop):
        # Always reads the *current* up_to_date_adj_mat of the enclosing scope.
        update_one_step(overlapped_edge_features, distance_to_others, node_sequence, focus, target, bond_type,
                        edge_type_masks, valences, incremental_adj_mat, color, real_n_vertices, graph,
                        edge_type_labels, local_stop, edge_masks, edge_labels, is_stop, params,
                        params["check_overlap_edge"], new_mol, up_to_date_adj_mat, keep_prob)

    while queue:
        node_in_focus = queue.popleft()
        current_adj_list = graph[node_in_focus]
        # Shuffle in place (random order) or take a sorted copy (canonical order).
        if params["path_random_order"]:
            random.shuffle(current_adj_list)
        else:
            current_adj_list = sorted(current_adj_list)

        for neighbor, edge_type in current_adj_list:
            # Edges towards fully searched nodes have already been added.
            if not (color[neighbor] < 2):
                continue
            record_step(node_in_focus, neighbor, edge_type, False)
            # Add the edge and obtain a new adjacency.
            up_to_date_adj_mat = genereate_incremental_adj(
                up_to_date_adj_mat, node_in_focus, neighbor, edge_type)
            # An edge of type t uses t + 1 valence units on both endpoints.
            used = edge_type + 1
            valences[node_in_focus] -= used
            valences[neighbor] -= used
            # Mirror the edge in the incremental molecule.
            new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[edge_type])
            # Newly discovered nodes are queued for exploration.
            if color[neighbor] == 0:
                queue.append(neighbor)
                color[neighbor] = 1

        # Local stop: move on to another node or finish completely.
        record_step(node_in_focus, None, None, True)
        color[node_in_focus] = 2

    return (incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks,
            edge_type_labels, local_stop, edge_masks, edge_labels, overlapped_edge_features)
def construct_incremental_graph_preselected(dataset, edges, max_n_vertices, real_n_vertices, vertices_to_keep, exit_points, node_symbol, params, is_training_data, initial_idx=0, single_exit=False):
    """Replay a breadth-first construction that starts from a preselected substructure.

    Every edge whose two endpoints are both in ``vertices_to_keep`` is added up
    front without recording a step. The BFS then starts from the exit points
    and records, via ``update_one_step``, each remaining edge before it is added
    plus one "local stop" step per finished node.

    Args:
        dataset: forwarded to ``get_initial_valence``, ``sample_node_symbol``
            and ``add_atoms``.
        edges: list of ``(src, edge_type, dst)`` triples; the reverse direction
            of every edge is added here.
        max_n_vertices: length of the color array.
        real_n_vertices: number of vertices considered when masking/labelling.
        vertices_to_keep: node indices of the preselected substructure. Kept
            vertices that are not exit points get valence 0 and color 2.
        exit_points: kept vertices from which growth may continue.
        node_symbol: per-node vectors; the argmax of each is passed to
            ``get_initial_valence``.
        params: mapping providing ``"generation"``, ``"bfs_path_count"``,
            ``"path_random_order"``, ``"check_overlap_edge"`` and the keys read
            by ``update_one_step``.
        is_training_data: together with ``params["generation"]`` triggers the
            early return of nine empty lists.
        initial_idx: start node used only when no exit point was queued.
        single_exit: when true, one exit point is drawn with ``random.choice``
            and only that one is queued. With an empty ``exit_points`` nothing
            is drawn and the ``initial_idx`` fallback applies.

    Returns:
        A 9-tuple of lists: ``(incremental_adj_mat, distance_to_others,
        node_sequence, edge_type_masks, edge_type_labels, local_stop,
        edge_masks, edge_labels, overlapped_edge_features)``.
    """
    # Avoid calculating this if it is just for generating new molecules.
    if params["generation"] and is_training_data:
        return [], [], [], [], [], [], [], [], []

    # Avoid an initial index that is not below real_n_vertices.
    if initial_idx >= real_n_vertices:
        initial_idx = 0

    # Maximum valences for each node.
    valences = get_initial_valence([np.argmax(symbol) for symbol in node_symbol], dataset)

    # Breadth first search over the molecule.
    # Color 0: have not found, 1: in the queue, 2: searched already.
    color = [0] * max_n_vertices
    # The queue starts empty; initial_idx is only added if nothing else is queued.
    queue = deque()

    # Adjacency of the partially built molecule (no edges yet).
    up_to_date_adj_mat = defaultdict(list)

    # One entry per recorded step in each of these lists.
    incremental_adj_mat = []
    distance_to_others = []
    overlapped_edge_features = []
    node_sequence = []
    edge_type_masks = []
    edge_type_labels = []
    edge_masks = []
    edge_labels = []
    local_stop = []

    # The incremental molecule: all atoms up front, bonds added below.
    new_mol = Chem.rdchem.RWMol(Chem.MolFromSmiles(''))
    add_atoms(new_mol, sample_node_symbol([node_symbol], [len(node_symbol)], dataset)[0], dataset)

    # Keep probability for transition sampling (to form a binomial distribution).
    sample_transition_count = real_n_vertices + len(edges) / 2
    keep_prob = float(sample_transition_count) / (sample_transition_count * params["bfs_path_count"])

    # Add edges between kept vertices. Sets give O(1) membership tests both here
    # and in the BFS inner loop below.
    kept_vertices = set(vertices_to_keep)
    edges_to_ignore = set()
    for src, edge_type, dst in edges:
        if src in kept_vertices and dst in kept_vertices:
            # Add the edge and obtain a new adjacency.
            up_to_date_adj_mat = genereate_incremental_adj(
                up_to_date_adj_mat, src, dst, edge_type)
            # An edge of type t uses t + 1 valence units on both endpoints.
            valences[src] -= (edge_type + 1)
            valences[dst] -= (edge_type + 1)
            # Mirror the edge in the incremental molecule.
            new_mol.AddBond(int(src), int(dst), number_to_bond[edge_type])
            # Both directions are skipped later during the BFS.
            edges_to_ignore.add((src, edge_type, dst))
            edges_to_ignore.add((dst, edge_type, src))

    # If single exit, only one of the exit points is queued. random.choice
    # raises on an empty sequence, so an empty exit_points is passed through.
    if single_exit and len(exit_points) > 0:
        exit_to_use = [random.choice(exit_points)]
    else:
        exit_to_use = exit_points

    # Queue the chosen exit points and mask out the rest of the kept vertices.
    for v in vertices_to_keep:
        if v in exit_points:
            if v in exit_to_use:
                queue.append(v)
                color[v] = 1
        else:
            # Mask out nodes that aren't exit points.
            valences[v] = 0
            color[v] = 2

    # Fall back to initial_idx if no exit point was queued.
    if not queue:
        queue.append(initial_idx)
        color[initial_idx] = 1

    # Add backward edges (forward edges first, so adjacency order is stable).
    edges_bw = [(dst, edge_type, src) for src, edge_type, dst in edges]
    edges = edges + edges_bw

    # Construct a graph object using the symmetric edge list.
    graph = defaultdict(list)
    for src, edge_type, dst in edges:
        graph[src].append((dst, edge_type))

    while queue:
        node_in_focus = queue.popleft()
        current_adj_list = graph[node_in_focus]
        # Sort (canonical order) or shuffle in place (random order).
        if not params["path_random_order"]:
            current_adj_list = sorted(current_adj_list)
        else:
            random.shuffle(current_adj_list)

        for neighbor, edge_type in current_adj_list:
            # Skip preselected edges and edges towards fully searched nodes.
            if (node_in_focus, edge_type, neighbor) in edges_to_ignore or not (color[neighbor] < 2):
                continue
            update_one_step(overlapped_edge_features, distance_to_others, node_sequence, node_in_focus, neighbor, edge_type,
                            edge_type_masks, valences, incremental_adj_mat, color, real_n_vertices, graph,
                            edge_type_labels, local_stop, edge_masks, edge_labels, False, params, params["check_overlap_edge"], new_mol,
                            up_to_date_adj_mat, keep_prob)
            # Add the edge and obtain a new adjacency.
            up_to_date_adj_mat = genereate_incremental_adj(
                up_to_date_adj_mat, node_in_focus, neighbor, edge_type)
            # Suppose the edge is selected and update valences after adding it.
            valences[node_in_focus] -= (edge_type + 1)
            valences[neighbor] -= (edge_type + 1)
            # Mirror the edge in the incremental molecule.
            new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[edge_type])
            # Newly discovered nodes are queued for exploration.
            if color[neighbor] == 0:
                queue.append(neighbor)
                color[neighbor] = 1

        # Local stop: move on to another node or finish completely.
        update_one_step(overlapped_edge_features, distance_to_others, node_sequence, node_in_focus, None, None, edge_type_masks,
                        valences, incremental_adj_mat, color, real_n_vertices, graph,
                        edge_type_labels, local_stop, edge_masks, edge_labels, True, params, params["check_overlap_edge"], new_mol,
                        up_to_date_adj_mat, keep_prob)
        color[node_in_focus] = 2

    return (incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks,
            edge_type_labels, local_stop, edge_masks, edge_labels, overlapped_edge_features)