Skip to content

Commit

Permalink
Make dispatcher work with dynamic ast nodes
Browse files Browse the repository at this point in the history
Update dispatcher tests
sqlwhat already has tests for this functionality on a SQL grammar
  • Loading branch information
hermansje committed Mar 11, 2019
1 parent 7d8b69d commit 6dff876
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
37 changes: 27 additions & 10 deletions protowhat/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@


class Selector(NodeVisitor):
def __init__(self, src, priority=None):
def __init__(self, src, src_name=None, strict=True, priority=None):
self.src = src
self.src_name = src_name
self.strict = strict
self.priority = src._priority if priority is None else priority
self.out = []

Expand All @@ -23,29 +25,44 @@ def visit_list(self, lst):
self.visit(item)

def is_match(self, node):
if type(node) is self.src:
return True
if self.strict:
if type(node) is self.src:
return True
else:
return False
else:
return False
if isinstance(node, self.src) and (
self.src_name is None or self.src_name == node.__class__.__name__
):
return True
else:
return False

def has_priority_over(self, node):
return self.priority > node._priority


class Dispatcher:
def __init__(self, nodes, ast=None, safe_parsing=True):
def __init__(self, node_cls, nodes=None, ast=None, safe_parsing=True):
"""Wrapper to instantiate and use a Selector using node names."""
self.nodes = nodes
self.node_cls = node_cls
self.nodes = nodes or []
self.ast = ast
self.safe_parsing = safe_parsing

self.ParseError = getattr(self.ast, "ParseError", None)

def __call__(self, name, index, node, *args, **kwargs):
# TODO: gentle error handling
ast_cls = self.nodes[name]
if name in self.nodes:
ast_cls = self.nodes[name]
strict_selector = True
else:
ast_cls = self.node_cls
strict_selector = False

selector = Selector(ast_cls, *args, **kwargs)
selector = Selector(
ast_cls, src_name=name, strict=strict_selector, *args, **kwargs
)
selector.visit(node, head=True)

return selector.out[index]
Expand Down Expand Up @@ -81,7 +98,7 @@ def from_module(cls, mod):
for k, v in vars(mod).items()
if (inspect.isclass(v) and issubclass(v, mod.AstNode))
}
dispatcher = cls(ast_nodes, ast=mod)
dispatcher = cls(mod.AstNode, nodes=ast_nodes, ast=mod)
return dispatcher


Expand Down
6 changes: 3 additions & 3 deletions tests/test_check_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def state():
pre_exercise_code = "NA",
student_result = "", solution_result = "",
student_conn = None, solution_conn = None,
ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey())
ast_dispatcher = Dispatcher(ast.AST, DUMMY_NODES, ParseHey())
)

def test_initial_state():
State(student_code = {'script.py': '1'}, solution_code = {'script.py': '1'},
reporter = Reporter(), pre_exercise_code = "",
student_result = "", solution_result = "",
student_conn = None, solution_conn = None,
ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey()))
ast_dispatcher = Dispatcher(ast.AST, DUMMY_NODES, ParseHey()))

def test_check_file_use_fs(state, tf):
state.solution_code = { tf.name: '3 + 3' }
Expand Down Expand Up @@ -84,7 +84,7 @@ def code_state():
pre_exercise_code = "NA",
student_result = "", solution_result = "",
student_conn = None, solution_conn = None,
ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey())
ast_dispatcher = Dispatcher(ast.AST, DUMMY_NODES, ParseHey())
)

def test_check_file(code_state):
Expand Down

0 comments on commit 6dff876

Please sign in to comment.