From 6dff876953cbde2f63b3ad57c7cafe764f25b379 Mon Sep 17 00:00:00 2001 From: Jeroen Hermans Date: Thu, 7 Mar 2019 20:06:24 +0100 Subject: [PATCH] Make dispatcher work with dynamic ast nodes Update dispatcher tests sqlwhat already has tests for this functionality on a SQL grammar --- protowhat/selectors.py | 37 +++++++++++++++++++++++++++---------- tests/test_check_files.py | 6 +++--- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/protowhat/selectors.py b/protowhat/selectors.py index 19effb1..133098a 100644 --- a/protowhat/selectors.py +++ b/protowhat/selectors.py @@ -4,8 +4,10 @@ class Selector(NodeVisitor): - def __init__(self, src, priority=None): + def __init__(self, src, src_name=None, strict=True, priority=None): self.src = src + self.src_name = src_name + self.strict = strict self.priority = src._priority if priority is None else priority self.out = [] @@ -23,29 +25,44 @@ def visit_list(self, lst): self.visit(item) def is_match(self, node): - if type(node) is self.src: - return True + if self.strict: + if type(node) is self.src: + return True + else: + return False else: - return False + if isinstance(node, self.src) and ( + self.src_name is None or self.src_name == node.__class__.__name__ + ): + return True + else: + return False def has_priority_over(self, node): return self.priority > node._priority class Dispatcher: - def __init__(self, nodes, ast=None, safe_parsing=True): + def __init__(self, node_cls, nodes=None, ast=None, safe_parsing=True): """Wrapper to instantiate and use a Selector using node names.""" - self.nodes = nodes + self.node_cls = node_cls + self.nodes = nodes or [] self.ast = ast self.safe_parsing = safe_parsing self.ParseError = getattr(self.ast, "ParseError", None) def __call__(self, name, index, node, *args, **kwargs): - # TODO: gentle error handling - ast_cls = self.nodes[name] + if name in self.nodes: + ast_cls = self.nodes[name] + strict_selector = True + else: + ast_cls = self.node_cls + strict_selector = False - selector = Selector(ast_cls, *args, **kwargs) + selector = Selector( + ast_cls, src_name=name, strict=strict_selector, *args, **kwargs + ) selector.visit(node, head=True) return selector.out[index] @@ -81,7 +98,7 @@ def from_module(cls, mod): for k, v in vars(mod).items() if (inspect.isclass(v) and issubclass(v, mod.AstNode)) } - dispatcher = cls(ast_nodes, ast=mod) + dispatcher = cls(mod.AstNode, nodes=ast_nodes, ast=mod) return dispatcher diff --git a/tests/test_check_files.py b/tests/test_check_files.py index afe96b1..8e60216 100644 --- a/tests/test_check_files.py +++ b/tests/test_check_files.py @@ -38,7 +38,7 @@ def state(): pre_exercise_code = "NA", student_result = "", solution_result = "", student_conn = None, solution_conn = None, - ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey()) + ast_dispatcher = Dispatcher(ast.AST, DUMMY_NODES, ParseHey()) ) def test_initial_state(): @@ -46,7 +46,7 @@ def test_initial_state(): reporter = Reporter(), pre_exercise_code = "", student_result = "", solution_result = "", student_conn = None, solution_conn = None, - ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey())) + ast_dispatcher = Dispatcher(ast.AST, DUMMY_NODES, ParseHey())) def test_check_file_use_fs(state, tf): state.solution_code = { tf.name: '3 + 3' } @@ -84,7 +84,7 @@ def code_state(): pre_exercise_code = "NA", student_result = "", solution_result = "", student_conn = None, solution_conn = None, - ast_dispatcher = Dispatcher(DUMMY_NODES, ParseHey()) + ast_dispatcher = Dispatcher(ast.AST, DUMMY_NODES, ParseHey()) ) def test_check_file(code_state):