Skip to content

Commit

Permalink
Fix old function usage
Browse files Browse the repository at this point in the history
  • Loading branch information
hermansje committed Mar 11, 2019
1 parent a7be07b commit 57977ff
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 10 deletions.
5 changes: 2 additions & 3 deletions example/example_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
import asttokens

class PythonAst:
def _get_text(self, code):
def get_text(self, code):
atok = asttokens.ASTTokens(code, tree = self)
return atok.get_text(self)

def patch_ast():
# forces class to use PythonAst as a mixin
# allows them to use _get_text method
# allows them to use get_text method
for obj in ast.__dict__.values():
if inspect.isclass(obj) and issubclass(obj, ast.AST):
if obj == ast.AST or PythonAst in obj.__bases__: continue
obj.__bases__ = obj.__bases__ + (PythonAst, )
obj._priority = 0
obj._get_field_names = lambda self: self._fields

# add AstNode, and ParseError classes for Dispatcher
ast.AstNode = PythonAst
Expand Down
4 changes: 2 additions & 2 deletions protowhat/checks/check_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def get_text(ast, code):
if isinstance(ast, ParseError):
return code
try:
return ast._get_text(code)
return ast.get_text(code)
except:
return code

Expand Down Expand Up @@ -300,7 +300,7 @@ def get_str(ast, code, sql):
if isinstance(ast, str):
return ast
try:
return ast._get_text(code)
return ast.get_text(code)
except:
return None

Expand Down
2 changes: 1 addition & 1 deletion protowhat/utils_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class AstNode(AST):
_fields = []
_priority = 1

def _get_text(self, text):
def get_text(self, text):
raise NotImplemented()

def get_position(self):
Expand Down
6 changes: 2 additions & 4 deletions tests/test_check_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# TODO: selectors require a _priority attribute and _get_field_names
# this is a holdover from the sql ast modules
ast.Expr._priority = 0
ast.Module._get_field_names = lambda self: self._fields
ast.Expr._get_field_names = lambda self: self._fields
DUMMY_NODES = {'Expr': ast.Expr}

class ParseHey:
Expand All @@ -37,7 +35,7 @@ def state():
solution_code = "",
reporter = Reporter(),
# args below should be ignored
pre_exercise_code = "NA",
pre_exercise_code = "NA",
student_result = "", solution_result = "",
student_conn = None, solution_conn = None,
ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey())
Expand Down Expand Up @@ -83,7 +81,7 @@ def code_state():
solution_code = {'script1.py': '3 + 3', 'script2.py': '4 + 4'},
reporter = Reporter(),
# args below should be ignored
pre_exercise_code = "NA",
pre_exercise_code = "NA",
student_result = "", solution_result = "",
student_conn = None, solution_conn = None,
ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey())
Expand Down

0 comments on commit 57977ff

Please sign in to comment.