-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_operators.py
335 lines (313 loc) · 12.2 KB
/
model_operators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import inspect, ast, astor
from ast import *
class ChangeClassName(ast.NodeTransformer):
    """
    Rename every class definition in the visited tree to ``new_name``.

    Nested classes are handled before the class that encloses them, and all
    of them receive the same new name. The ClassDef nodes are renamed in place.
    """

    def __init__(self, new_name):
        super().__init__()
        self.new_name = new_name

    def visit_ClassDef(self, node):
        # Descend first so nested class definitions are renamed as well.
        renamed = self.generic_visit(node)
        renamed.name = self.new_name
        return renamed
class ChangeFunctionName(ast.NodeTransformer):
    """
    Rename every (non-async) function definition in the visited tree to ``new_name``.

    Inner functions are handled before the function that encloses them, and
    all of them receive the same new name. Nodes are renamed in place.
    """

    def __init__(self, new_name):
        super().__init__()
        self.new_name = new_name

    def visit_FunctionDef(self, node):
        # Descend first so nested function definitions are renamed as well.
        renamed = self.generic_visit(node)
        renamed.name = self.new_name
        return renamed
#class AddToFunctionBody(ast.NodeTransformer):
# """
# Adds code to either the beginning (default) or end of a function
# """
# def __init__(self, code, head=True, skip_lines = 0):
# self.code = code
# self.head = head
# self.skip_lines = skip_lines
# super().__init__()
# def visit_FunctionDef(self, node):
# self.generic_visit(node)
# if self.head:
# #insert the whole code statement first in the body of the function
# #node.body = [self.code] + node.body
# node.body = node.body[:self.skip_lines] + [self.code] + node.body[self.skip_lines:]
# else:
# #append the whole code statement last in the body of the function
# node.body.append(self.code)
# # switch code and return statement so it's last
# node.body[-1], node.body[-2] = node.body[-2], node.body[-1]
# return node
class AddToFunctionBody(ast.NodeTransformer):
    """
    Insert one or more statements into the body of every visited function.

    ``code`` is a single AST statement or a list of statements; the statements
    keep their relative order. ``pos`` is a list-slice index into the function
    body:

    * ``0`` (default) inserts at the beginning;
    * a non-negative ``k`` inserts starting at index ``k``;
    * a negative ``k`` inserts just before the last ``-k`` statements
      (``-1`` keeps a trailing ``return`` last).

    The insertion point is recomputed from ``pos`` for every function, so
    visiting several (or nested) functions, or reusing the transformer, always
    inserts at the same place. The same statement objects are shared between
    every function that receives them.
    """
    def __init__(self, code, pos=0):
        self.code = code
        self.pos = pos
        super().__init__()

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        stmts = self.code if isinstance(self.code, list) else [self.code]
        # Local cursor: self.pos must not drift between visited functions.
        cursor = self.pos
        for stmt in stmts:
            node.body = node.body[:cursor] + [stmt] + node.body[cursor:]
            # A negative index already points just before the same trailing
            # statements after each insertion, so only forward indices advance.
            if cursor >= 0:
                cursor += 1
        return node
class AddToPlate(ast.NodeTransformer):
    """
    Find the plate named ``plate_name`` and insert code into its body, or delete it.

    A plate is a ``with <module>.plate(f'<plate_name>_{...}', ...):`` statement.
    The literal f-string text before the first substitution must equal
    ``plate_name + '_'``. Any other ``with`` statement is left alone.

    ``code`` is a single AST statement or a list of statements. They are
    inserted in order at list-slice index ``pos`` of the plate body. A negative
    ``pos`` inserts just before the last ``-pos`` statements.

    If ``code`` is None, the matching plate is removed from the tree.

    In either case the matching ``with`` node is stored in ``self.plate``
    (None if no plate matched).
    """
    def __init__(self, plate_name, code, pos=0):
        self.plate_name = plate_name
        self.code = code
        self.pos = pos
        self.plate = None
        super().__init__()

    @staticmethod
    def _plate_label(stmt):
        """Return the literal f-string prefix naming a plate ``with``, or None if not a plate."""
        if not isinstance(stmt, ast.With) or not stmt.items:
            return None
        call = stmt.items[0].context_expr
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Attribute)
                and call.func.attr == 'plate'
                and call.args):
            return None
        name = call.args[0]
        if (isinstance(name, ast.JoinedStr) and name.values
                and isinstance(name.values[0], ast.Constant)):
            return name.values[0].value
        return None

    def visit_With(self, node):
        # Visit children first so nested plates are handled too.
        self.generic_visit(node)
        if self._plate_label(node) != self.plate_name + '_':
            return node
        self.plate = node
        if self.code is None:
            # Returning None makes NodeTransformer drop the node from its parent.
            return None
        stmts = self.code if isinstance(self.code, list) else [self.code]
        cursor = self.pos
        for stmt in stmts:
            node.body = node.body[:cursor] + [stmt] + node.body[cursor:]
            if cursor >= 0:
                cursor += 1
        return node
class GetPlateIndex(ast.NodeTransformer):
    """
    Get the index of the sought plate in the body list of the object that contains it.

    The container is either a function or an enclosing ``with ...plate``.

    A plate is a ``with <module>.plate(f'<name>_{...}', ...):`` statement; any
    other ``with`` statement is ignored. After visiting:

    * ``self.pos`` is the index of the plate named ``plate_name`` within its
      container's body, or -1 if it was not found;
    * ``self.container`` is the enclosing plate's name (its f-string prefix
      without the trailing '_'), ``''`` if the plate sits directly in a
      function body, or None if it was not found.

    The tree is not modified.
    """
    def __init__(self, plate_name):
        self.plate_name = plate_name
        self.pos = -1
        self.container = None
        super().__init__()

    @staticmethod
    def _plate_label(stmt):
        """Return the literal f-string prefix naming a plate ``with``, or None if not a plate."""
        if not isinstance(stmt, ast.With) or not stmt.items:
            return None
        call = stmt.items[0].context_expr
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Attribute)
                and call.func.attr == 'plate'
                and call.args):
            return None
        name = call.args[0]
        if (isinstance(name, ast.JoinedStr) and name.values
                and isinstance(name.values[0], ast.Constant)):
            return name.values[0].value
        return None

    def _index_in(self, body):
        """Return the index of the last sought plate in ``body``, or None if absent."""
        target = self.plate_name + '_'
        found = None
        for i, stmt in enumerate(body):
            if self._plate_label(stmt) == target:
                found = i
        return found

    def visit_With(self, node):
        self.generic_visit(node)
        label = self._plate_label(node)
        if label is None:
            return node
        index = self._index_in(node.body)
        if index is not None:
            self.pos = index
            self.container = label[:-1]  # drop the trailing '_' separator
        return node

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        index = self._index_in(node.body)
        if index is not None:
            self.pos = index
            self.container = ''
        return node
class AddToForLoop(ast.NodeTransformer):
    """
    Append code to every ``for`` loop whose loop variable is ``target_indexing_variable``.

    ``code`` may be a single AST statement or a list of statements, which are
    appended in order. If ``code`` is None, the matching loops' bodies are
    emptied instead. An empty body is not valid Python source until something
    is added back.

    Loops whose target is not a single name (e.g. ``for i, x in ...``) never
    match.
    """
    def __init__(self, target_indexing_variable, code):
        self.target_indexing_variable = target_indexing_variable
        self.code = code
        super().__init__()

    def visit_For(self, node):
        # Visit children first so nested loops are handled too.
        self.generic_visit(node)
        target = node.target
        if not (isinstance(target, ast.Name)
                and target.id == self.target_indexing_variable):
            return node
        if self.code is None:
            node.body = []
        elif isinstance(self.code, list):
            node.body.extend(self.code)
        else:
            node.body.append(self.code)
        return node
class AddReturn(ast.NodeTransformer):
    """
    Append a ``return`` statement to the body of every visited function.

    ``return_tuple`` may be:

    * a single variable name (``str``), giving ``return name``;
    * an iterable of variable names and/or AST expressions, giving
      ``return (a, b, ...)``;
    * a single AST expression, giving ``return <expr>``.

    The same Return node object is appended to every visited function.
    """
    def __init__(self, return_tuple):
        if isinstance(return_tuple, str):
            # A str is iterable too; without this check it became a tuple of characters.
            value = ast.Name(id=return_tuple, ctx=ast.Load())
        elif hasattr(return_tuple, '__iter__'):
            # Several values are returned as a tuple.
            value = ast.Tuple(elts=[self._as_expr(e) for e in return_tuple],
                              ctx=ast.Load())
        else:
            # Otherwise it is already a single AST expression.
            value = return_tuple
        self.return_tuple = ast.Return(value=value)
        super().__init__()

    @staticmethod
    def _as_expr(elem):
        """Wrap a variable name in a Name node; pass AST nodes through unchanged."""
        if isinstance(elem, ast.AST):
            return elem
        return ast.Name(id=elem, ctx=ast.Load())

    def visit_FunctionDef(self, node):
        # Visit children first so nested functions are handled too.
        self.generic_visit(node)
        node.body.append(self.return_tuple)
        return node
class AddArgsToFunctionDef(ast.NodeTransformer):
    """
    Insert positional argument(s) into every visited function definition.

    ``arg_to_add`` is an argument name, or a list/tuple of names inserted in
    order. ``pos`` is a list-slice index into the existing positional arguments:

    * ``0`` (default) puts the new argument(s) first;
    * ``len(args)`` or larger puts them last;
    * ``-1`` puts them just before the current last argument.

    The new arguments have no annotation. Default values (``args.defaults``)
    are not adjusted.
    """
    def __init__(self, arg_to_add, pos = 0):
        self.pos = pos
        self.arg_to_add = arg_to_add
        super().__init__()

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        if isinstance(self.arg_to_add, (list, tuple)):
            names = list(self.arg_to_add)
        else:
            names = [self.arg_to_add]
        new_args = [ast.arg(arg=name, annotation=None) for name in names]
        args = node.args.args
        node.args.args = args[:self.pos] + new_args + args[self.pos:]
        return node
class DeletePlate(ast.NodeTransformer):
    """
    Remove every plate whose name starts with ``plate_name`` and keep the removed node.

    A plate is a ``with <module>.plate(f'<name>...', ...):`` statement. Its
    name is the literal f-string text before the first substitution. Other
    ``with`` statements are kept.

    The removed ``with`` node is stored in ``self.code``. If several match, the
    last one visited is kept (the outermost, because children are visited
    first). If none match, ``self.code`` is None.

    ``code`` is accepted only for backward compatibility and is ignored.
    """
    def __init__(self, plate_name, code=None):
        self.plate_name = plate_name
        self.code = None
        super().__init__()

    @staticmethod
    def _plate_label(stmt):
        """Return the literal f-string prefix naming a plate ``with``, or None if not a plate."""
        if not isinstance(stmt, ast.With) or not stmt.items:
            return None
        call = stmt.items[0].context_expr
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Attribute)
                and call.func.attr == 'plate'
                and call.args):
            return None
        name = call.args[0]
        if (isinstance(name, ast.JoinedStr) and name.values
                and isinstance(name.values[0], ast.Constant)):
            return name.values[0].value
        return None

    def visit_With(self, node):
        self.generic_visit(node)
        label = self._plate_label(node)
        # NOTE(review): this is a prefix match, so 'K' also matches e.g. 'KL_'.
        # AddToPlate requires plate_name + '_' exactly; confirm which is intended.
        if label is not None and label.startswith(self.plate_name):
            self.code = node
            # Returning None drops the node from its parent's body.
            return None
        return node
class CutFromPlate(ast.NodeTransformer):
    """
    Cut the first or last statement from the body of every plate whose name starts with ``plate_name``.

    ``head=True`` (default) cuts the first statement; ``head=False`` cuts the
    last one.

    A plate is a ``with <module>.plate(f'<name>...', ...):`` statement. Its
    name is the literal f-string text before the first substitution.

    The cut statement is stored in ``self.code``. It comes from the last plate
    processed, or is None if nothing was cut. Plates with an empty body, and
    ``with`` statements that are not plates, are left untouched.
    """
    def __init__(self, plate_name, head = True):
        self.plate_name = plate_name
        self.head = head
        self.code = None
        super().__init__()

    @staticmethod
    def _plate_label(stmt):
        """Return the literal f-string prefix naming a plate ``with``, or None if not a plate."""
        if not isinstance(stmt, ast.With) or not stmt.items:
            return None
        call = stmt.items[0].context_expr
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Attribute)
                and call.func.attr == 'plate'
                and call.args):
            return None
        name = call.args[0]
        if (isinstance(name, ast.JoinedStr) and name.values
                and isinstance(name.values[0], ast.Constant)):
            return name.values[0].value
        return None

    def visit_With(self, node):
        # Visit children first so nested plates are handled too.
        self.generic_visit(node)
        label = self._plate_label(node)
        if label is None or not label.startswith(self.plate_name) or not node.body:
            return node
        if self.head:
            # Cut the first element in body.
            self.code, node.body = node.body[0], node.body[1:]
        else:
            # Cut the last element in body.
            self.code, node.body = node.body[-1], node.body[:-1]
        return node
class CutArgsFromFunctionDef(ast.NodeTransformer):
    """
    Cut positional arguments from every visited function definition.

    ``n_args_to_cut`` arguments (default 1) are removed from the beginning
    (``head=True``, default) or the end (``head=False``) of the positional
    argument list.

    A count of 0 or less removes nothing. A count at or above the number of
    arguments removes them all. Default values (``args.defaults``) are not
    adjusted.
    """
    def __init__(self, n_args_to_cut = 1, head = True):
        self.n_args_to_cut = n_args_to_cut
        self.head = head
        super().__init__()

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        args = node.args.args
        n = max(0, self.n_args_to_cut)
        if self.head:
            node.args.args = args[n:]
        else:
            # Use an explicit stop index: args[:-0] would drop every argument.
            node.args.args = args[:max(0, len(args) - n)]
        return node
class CutFromFunctionBody(ast.NodeTransformer):
    """
    Cut the first or last statement from the body of every visited function.

    ``head=True`` (default) cuts the first statement; ``head=False`` cuts the
    last one.

    The cut statement is stored in ``self.code``. It comes from the last
    function processed (the outermost, because children are visited first),
    or is None if nothing was cut. Functions whose body is already empty are
    left untouched.
    """
    def __init__(self, head=True):
        self.head = head
        self.code = None
        super().__init__()

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        if not node.body:
            # Nothing left to cut (e.g. after repeated cuts).
            return node
        if self.head:
            # Cut the first element in body.
            self.code, node.body = node.body[0], node.body[1:]
        else:
            # Cut the last element in body.
            self.code, node.body = node.body[-1], node.body[:-1]
        return node
class GetObservationModel(ast.NodeTransformer):
    """
    Find the ``<module>.sample('obs', ...)`` call and store it in ``self.code``.

    ``self.code`` is the whole Call node. If several such calls exist, the last
    one visited is kept; if none exist, ``self.code`` is None.

    Calls of any other shape are ignored: plain function calls such as
    ``len(x)``, calls with no positional arguments, and sample sites whose
    first argument is not the literal string ``'obs'``. The tree is not
    modified.
    """
    def __init__(self):
        super().__init__()
        self.code = None

    @staticmethod
    def _is_obs_sample(node):
        """True if ``node`` is ``<something>.sample('obs', ...)``."""
        return (isinstance(node.func, ast.Attribute)
                and node.func.attr == 'sample'
                and bool(node.args)
                and isinstance(node.args[0], ast.Constant)
                and node.args[0].value == 'obs')

    def visit_Call(self, node):
        # Visit child nodes first so nested calls are inspected too.
        self.generic_visit(node)
        if self._is_obs_sample(node):
            self.code = node
        return node
class GetNames(ast.NodeTransformer):
    """
    Collect every ``ast.Name`` node of the visited tree, in traversal order, into ``self.code``.

    Module names such as ``torch`` in ``torch.zeros(D)`` are collected too. To
    avoid that, bind the value to a variable in the model
    (``loc = torch.zeros(D)``) and use ``loc`` in the observation model.
    """
    def __init__(self):
        super().__init__()
        self.code = []

    def visit_Name(self, node):
        self.code.append(node)
        # A Name's only child is its ctx (Load/Store/Del), so descending after
        # recording it leaves the collection order unchanged.
        return self.generic_visit(node)
class GetDistributionParameters(ast.NodeTransformer):
    """
    Collect the variable names passed to ``dist.<Something>(...)`` calls into ``self.code``.

    A call matches when its callee is an attribute of the bare name ``dist``.
    For each matching call, positional and then keyword arguments are
    examined. Those that are plain variable names are appended to
    ``self.code`` as strings.

    Arguments that are not bare names are skipped: literals such as ``1.0``,
    expressions such as ``x * 2``, and ``*args``. Calls of any other shape,
    such as ``len(x)`` or ``a.b.c()``, are ignored. The tree is not modified.
    """
    def __init__(self):
        super().__init__()
        self.code = []

    def visit_Call(self, node):
        # Visit child nodes first so nested dist calls are collected too.
        self.generic_visit(node)
        func = node.func
        if not (isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == 'dist'):
            return node
        for value in node.args:
            if isinstance(value, ast.Name):
                self.code.append(value.id)
        for kw in node.keywords:
            if isinstance(kw.value, ast.Name):
                self.code.append(kw.value.id)
        return node
class ChangeObservationModel(ast.NodeTransformer):
    """
    Replace the distribution of the ``<module>.sample('obs', ...)`` site with ``new_obs_model``.

    ``new_obs_model`` is an AST expression. It replaces the call's second
    positional argument, or its ``fn=`` keyword argument when the distribution
    is passed by keyword.

    Calls of any other shape are left unchanged: plain calls such as
    ``len(x)``, calls with no arguments, and sites not named with the literal
    ``'obs'``. The same ``new_obs_model`` node is used for every matching site.
    """
    def __init__(self, new_obs_model):
        self.new_obs_model = new_obs_model
        super().__init__()

    @staticmethod
    def _is_obs_sample(node):
        """True if ``node`` is ``<something>.sample('obs', ...)``."""
        return (isinstance(node.func, ast.Attribute)
                and node.func.attr == 'sample'
                and bool(node.args)
                and isinstance(node.args[0], ast.Constant)
                and node.args[0].value == 'obs')

    def visit_Call(self, node):
        # Visit child nodes first so nested calls are inspected too.
        self.generic_visit(node)
        if not self._is_obs_sample(node):
            return node
        if len(node.args) >= 2:
            node.args[1] = self.new_obs_model
        else:
            for kw in node.keywords:
                if kw.arg == 'fn':
                    kw.value = self.new_obs_model
        return node
def change_observation_model_to_LowRankMultivariateNormal(tree):
    """
    Replace the distribution of the ``'obs'`` sample site with a low-rank multivariate normal.

    The new distribution is
    ``dist.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag)``.

    ``tree`` is modified in place through ChangeObservationModel and is also
    returned for convenience. The generated call refers to the names ``loc``,
    ``cov_factor`` and ``cov_diag``, which must be defined where the sample
    site lives.
    """
    # Use ``dist`` (not ``dst``) so GetDistributionParameters, which only
    # recognizes ``dist.<...>`` calls, can see the new parameters.
    # NOTE(review): assumes the model binds its distributions module to
    # ``dist`` — confirm against the model sources.
    lowrank_normal_obs_model = ast.Call(
        func=ast.Attribute(value=ast.Name(id='dist', ctx=ast.Load()),
                           attr='LowRankMultivariateNormal', ctx=ast.Load()),
        args=[ast.Name(id='loc', ctx=ast.Load())],
        keywords=[
            ast.keyword(arg='cov_factor', value=ast.Name(id='cov_factor', ctx=ast.Load())),
            ast.keyword(arg='cov_diag', value=ast.Name(id='cov_diag', ctx=ast.Load())),
        ],
    )
    ChangeObservationModel(lowrank_normal_obs_model).visit(tree)
    return tree