Skip to content

Commit

Permalink
Fixed evaluation order and super problems, still not passing all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AryazE committed Feb 21, 2022
1 parent 08ca8ad commit 5e6be62
Show file tree
Hide file tree
Showing 21 changed files with 177,950 additions and 4,194 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ build/
dist/
.pytest_cache/
test/targetPrograms/*.orig
.hypothesis/
87 changes: 59 additions & 28 deletions build/lib/dynapyt/instrument/CodeInstrumenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from libcst.matchers import call_if_not_inside, call_if_inside
import libcst.helpers as helpers
from libcst.metadata.expression_context_provider import ExpressionContext
from libcst.metadata.scope_provider import QualifiedNameSource
from libcst.metadata.scope_provider import QualifiedNameSource, ClassScope
from numpy import isin


class CodeInstrumenter(m.MatcherDecoratableTransformer):
Expand All @@ -21,6 +22,7 @@ def __init__(self, src, file_path, iids, selected_hooks):
self.iids = iids
self.name_stack = []
self.current_try = []
self.current_class = []
self.selected_hooks = selected_hooks
self.to_import = set()

Expand All @@ -40,12 +42,30 @@ def __create_import(self, names):
stmt = cst.SimpleStatementLine(body=[imp])
return stmt

def __wrap_in_lambda(self, node):
# used_names = set(map(lambda x: x.value, m.findall(node, m.Name())))
def __wrap_in_lambda(self, original_node, updated_node):
if m.matches(updated_node, m.Call(func=m.Name('super'), args=[])):
class_arg = cst.Arg(value=cst.Name(value=self.current_class[-1]))
new_node = updated_node.with_changes(args=[class_arg, cst.Arg(value=cst.Name('self'))])
return cst.Lambda(params=cst.Parameters(params=[]), body=new_node)
used_names = list(m.findall(original_node, m.Name()))
unique_names = set()
parameters = []
# for n in used_names:
# parameters.append(cst.Param(name=cst.Name(value=n), default=cst.Name(value=n)))
lambda_expr = cst.Lambda(params=cst.Parameters(params=parameters), body=node)
try:
my_scope = self.get_metadata(ScopeProvider, original_node)
except KeyError:
my_scope = None
if isinstance(my_scope, ClassScope):
for n in used_names:
try:
name_source = self.get_metadata(QualifiedNameProvider, n)
n_scope = self.get_metadata(ScopeProvider, n)
except KeyError:
name_source = []
n_scope = None
if (n.value not in unique_names) and (my_scope == n_scope) and (len(list(name_source)) > 0) and (list(name_source)[0].source == QualifiedNameSource.LOCAL):
parameters.append(cst.Param(name=cst.Name(value=n.value), default=cst.Name(value=n.value)))
unique_names.add(n.value)
lambda_expr = cst.Lambda(params=cst.Parameters(params=parameters), body=updated_node)
return lambda_expr

def __as_string(self, s):
Expand All @@ -54,12 +74,19 @@ def __as_string(self, s):
else:
return '"' + s + '"'

def visit_Annotation(self, node):
return False

def visit_Decorator(self, node):
return False

def leave_Tuple(self, original_node, updated_node):
if len(updated_node.lpar) == 0:
return updated_node.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
return updated_node

# Top level

def leave_Module(self, original_node, updated_node):
imports_index = -1
# '\"\"\"' + self.source.replace('\"', '\\"') + '\"\"\"'
Expand Down Expand Up @@ -112,6 +139,13 @@ def leave_Module(self, original_node, updated_node):
new_body = list(updated_node.body[:imports_index+1]) + dynapyt_imports + [get_ast] + [try_body]
return updated_node.with_changes(body=new_body)

def visit_ClassDef(self, node):
self.current_class.append(node.name.value)

def leave_ClassDef(self, original_node, updated_node):
self.current_class.pop()
return updated_node

def leave_Expr(self, original_node, updated_node):
if 'expression' not in self.selected_hooks:
return updated_node
Expand All @@ -120,7 +154,7 @@ def leave_Expr(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
val_arg = cst.Arg(self.__wrap_in_lambda(original_node))
val_arg = cst.Arg(self.__wrap_in_lambda(original_node, updated_node))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, val_arg])
return updated_node.with_changes(value=call)

Expand Down Expand Up @@ -153,7 +187,7 @@ def leave_Name(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
var_arg = cst.Arg(value=self.__wrap_in_lambda(updated_node))
var_arg = cst.Arg(value=self.__wrap_in_lambda(original_node, updated_node))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, var_arg])
return call
else:
Expand Down Expand Up @@ -254,7 +288,7 @@ def leave_Del(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
target_arg = cst.Arg(value=self.__wrap_in_lambda(updated_node.target))
target_arg = cst.Arg(value=self.__wrap_in_lambda(original_node.target, updated_node.target))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, target_arg])
return cst.Expr(value=call)

Expand Down Expand Up @@ -305,10 +339,10 @@ def leave_BinaryOperation(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
left_arg = cst.Arg(updated_node.left)
left_arg = cst.Arg(self.__wrap_in_lambda(original_node.left, updated_node.left))
operator_name = type(original_node.operator).__name__
operator_arg = cst.Arg(cst.Integer(str(bin_op[operator_name])))
right_arg = cst.Arg(updated_node.right)
right_arg = cst.Arg(self.__wrap_in_lambda(original_node.right, updated_node.right))
call = cst.Call(func=callee_name, args=[
ast_arg, iid_arg, left_arg, operator_arg, right_arg])
return call
Expand All @@ -322,10 +356,10 @@ def leave_BooleanOperation(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
left_arg = cst.Arg(updated_node.left)
left_arg = cst.Arg(self.__wrap_in_lambda(original_node.left, updated_node.left))
operator_name = type(original_node.operator).__name__
operator_arg = cst.Arg(cst.Integer(str(bool_op[operator_name])))
right_arg = cst.Arg(updated_node.right)
right_arg = cst.Arg(self.__wrap_in_lambda(original_node.right, updated_node.right))
call = cst.Call(func=callee_name, args=[
ast_arg, iid_arg, left_arg, operator_arg, right_arg])
return call
Expand All @@ -347,7 +381,7 @@ def leave_UnaryOperation(self, original_node, updated_node):
return call

def leave_Comparison(self, original_node, updated_node):
if 'comparison' not in self.selected_hooks:
if ('comparison' not in self.selected_hooks) and (not any(type(i.operator).__name__ in self.selected_hooks for i in updated_node.comparisons)):
return updated_node
comp_op = {'Equal': 0, 'GreaterThan': 1, 'GreaterThanEqual': 2, 'In': 3,
'Is': 4, 'LessThan': 5, 'LessThanEqual': 6, 'NotEqual': 7,
Expand All @@ -359,18 +393,12 @@ def leave_Comparison(self, original_node, updated_node):
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
left_arg = cst.Arg(updated_node.left)
comparisons = []
instrument = False
for i in updated_node.comparisons:
operator_name = type(i.operator).__name__
if operator_name in self.selected_hooks:
instrument = True
comparisons.append(cst.Element(value=cst.Tuple(elements=[cst.Element(cst.Integer(str(comp_op[operator_name]))), cst.Element(i.comparator)])))
call = cst.Call(func=callee_name, args=[
ast_arg, iid_arg, left_arg, cst.Arg(cst.List(elements=comparisons))])
if instrument:
return call
else:
return updated_node
return call

def leave_Assign(self, original_node, updated_node):
if 'assignment' not in self.selected_hooks:
Expand All @@ -381,7 +409,7 @@ def leave_Assign(self, original_node, updated_node):
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
val_arg = cst.Arg(value=updated_node.value)
left_arg = cst.Arg(value=cst.List(elements=[cst.Element(self.__wrap_in_lambda(t.target)) for t in updated_node.targets]))
left_arg = cst.Arg(value=cst.List(elements=[cst.Element(self.__wrap_in_lambda(tu.target, tu.target)) for to, tu in zip(original_node.targets, updated_node.targets)]))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, val_arg, left_arg])
# new_targets = [t for t in original_node.targets if m.matches(t, m.AssignTarget(target=m.Name()))]
# old_targets = []
Expand All @@ -404,7 +432,7 @@ def leave_AugAssign(self, original_node, updated_node):
operator_name = type(original_node.operator).__name__
opr_arg = cst.Arg(value=cst.Integer(value=str(aug_op[operator_name])))
val_arg = cst.Arg(value=updated_node.value)
left_arg = cst.Arg(value=self.__wrap_in_lambda(updated_node.target))
left_arg = cst.Arg(value=self.__wrap_in_lambda(original_node.target, updated_node.target))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, left_arg, opr_arg, val_arg])
return updated_node.with_changes(value=call, target=original_node.target)

Expand All @@ -419,10 +447,13 @@ def leave_FunctionDef(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
args_arg = cst.Arg(value=cst.List(elements=[cst.Element(value=self.__wrap_in_lambda(p.name)) for p in updated_node.params.params]))
args_arg = cst.Arg(value=cst.List(elements=[cst.Element(value=self.__wrap_in_lambda(po.name, pu.name)) for po, pu in zip(original_node.params.params, updated_node.params.params)]))
entry_stmt = cst.Expr(cst.Call(func=enter_name, args=[ast_arg, iid_arg, args_arg]))
exit_stmt = cst.Expr(cst.Call(func=exit_name, args=[ast_arg, iid_arg]))
new_body = updated_node.body.with_changes(body=[cst.SimpleStatementLine([entry_stmt])]+list(updated_node.body.body)+[cst.SimpleStatementLine([exit_stmt])])
if m.matches(updated_node.body.body[0], m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])):
new_body = updated_node.body.with_changes(body=[updated_node.body.body[0], cst.SimpleStatementLine([entry_stmt])]+list(updated_node.body.body[1:])+[cst.SimpleStatementLine([exit_stmt])])
else:
new_body = updated_node.body.with_changes(body=[cst.SimpleStatementLine([entry_stmt])]+list(updated_node.body.body)+[cst.SimpleStatementLine([exit_stmt])])
new_node = updated_node
return new_node.with_changes(body=new_body)

Expand All @@ -434,8 +465,8 @@ def leave_Lambda(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
args_arg = cst.Arg(value=cst.List(elements=[cst.Element(value=self.__wrap_in_lambda(p.name)) for p in updated_node.params.params]))
body_arg = cst.Arg(value=self.__wrap_in_lambda(updated_node.body))
args_arg = cst.Arg(value=cst.List(elements=[cst.Element(value=self.__wrap_in_lambda(po.name, pu.name)) for po, pu in zip(original_node.params.params, updated_node.params.params)]))
body_arg = cst.Arg(value=self.__wrap_in_lambda(original_node.body, updated_node.body))
new_stmt = cst.Call(func=callee_name, args=[ast_arg, iid_arg, args_arg, body_arg])
return updated_node.with_changes(body=new_stmt)

Expand Down Expand Up @@ -483,7 +514,7 @@ def leave_Call(self, original_node, updated_node):
iid = self.__create_iid(original_node)
ast_arg = cst.Arg(value=cst.Name('_dynapyt_ast_'))
iid_arg = cst.Arg(value=cst.Integer(value=str(iid)))
call_arg = cst.Arg(value=self.__wrap_in_lambda(updated_node))
call_arg = cst.Arg(value=self.__wrap_in_lambda(original_node, updated_node))
call = cst.Call(func=callee_name, args=[ast_arg, iid_arg, call_arg])
return call

Expand Down
30 changes: 15 additions & 15 deletions build/lib/dynapyt/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,35 +82,35 @@ def _binary_op_(dyn_ast, iid, left, opr, right):
'LeftShift', 'MatrixMultiply', 'Modulo', 'Multiply', 'Power',
'RightShift', 'Subtract', 'And', 'Or']
if opr == 0:
result = left + right
result = left() + right()
elif opr == 1:
result = left & right
result = left() & right()
elif opr == 2:
result = left | right
result = left() | right()
elif opr == 3:
result = left ^ right
result = left() ^ right()
elif opr == 4:
result = left / right
result = left() / right()
elif opr == 5:
result = left // right
result = left() // right()
elif opr == 6:
result = left << right
result = left() << right()
elif opr == 7:
result = left @ right
result = left() @ right()
elif opr == 8:
result = left % right
result = left() % right()
elif opr == 9:
result = left * right
result = left() * right()
elif opr == 10:
result = left ** right
result = left() ** right()
elif opr == 11:
result = left >> right
result = left() >> right()
elif opr == 12:
result = left - right
result = left() - right()
elif opr == 13:
result = left and right
result = left() and right()
elif opr == 14:
result = left or right
result = left() or right()
result_high = call_if_exists('binary_op', dyn_ast, iid, bin_op[opr], left, right, result)
result_low = call_if_exists(snake(bin_op[opr]), dyn_ast, iid, left, right, result)
if result_low != None:
Expand Down
Loading

0 comments on commit 5e6be62

Please sign in to comment.