diff --git a/src/slipcover/branch.py b/src/slipcover/branch.py index 111ee86..6fea877 100644 --- a/src/slipcover/branch.py +++ b/src/slipcover/branch.py @@ -87,9 +87,11 @@ def visit_Match(self, node: ast.Match) -> ast.Match: for case in node.cases: case.body = self._mark_branch(node.lineno, case.body[0].lineno) + case.body - has_wildcard = isinstance(node.cases[-1].pattern, ast.MatchAs) and \ - node.cases[-1].pattern.pattern is None + last_pattern = case.pattern # case is node.cases[-1] + while isinstance(last_pattern, ast.MatchOr): + last_pattern = last_pattern.patterns[-1] + has_wildcard = case.guard is None and isinstance(last_pattern, ast.MatchAs) and last_pattern.pattern is None if not has_wildcard: to_line = node.next_node.lineno if node.next_node else 0 # exit node.cases.append(ast.match_case(ast.MatchAs(), diff --git a/src/slipcover/version.py b/src/slipcover/version.py index b19b12e..f871089 100644 --- a/src/slipcover/version.py +++ b/src/slipcover/version.py @@ -1 +1 @@ -__version__ = "1.0.14" +__version__ = "1.0.15" diff --git a/tests/test_branch.py b/tests/test_branch.py index 87952aa..51fd969 100644 --- a/tests/test_branch.py +++ b/tests/test_branch.py @@ -547,6 +547,22 @@ def test_match_case_with_false_guard(): assert [(3,7)] == g[br.BRANCH_NAME] +@pytest.mark.skipif(PYTHON_VERSION < (3,10), reason="New in 3.10") +def test_match_case_with_guard_isnt_wildcard(): + t = ast_parse(""" + def fun(v): + match v: + case _ if v > 0: + print("not default") + """) + + + t = br.preinstrument(t) + check_locations(t) + code = compile(t, "foo", "exec") + assert [(2,0), (2,4)] == get_branches(code) + + @pytest.mark.skipif(PYTHON_VERSION < (3,10), reason="New in 3.10") def test_match_branch_to_exit(): t = ast_parse(""" @@ -693,6 +709,39 @@ def test_branch_after_case_with_next(): assert [(2,4), (4,9)] == g[br.BRANCH_NAME] +@pytest.mark.skipif(PYTHON_VERSION < (3,10), reason="New in 3.10") +def test_match_wildcard_in_match_or(): + # Thanks to Ned Batchelder for this test case + t = ast_parse(f""" + def absurd(x): + match x: + case (3 | 99 | (999 | _)): + print("default") + absurd(5) + """) + + t = br.preinstrument(t) + check_locations(t) + code = compile(t, "foo", "exec") + assert [(2,4)] == get_branches(code) + + +@pytest.mark.skipif(PYTHON_VERSION < (3,10), reason="New in 3.10") +def test_match_capture(): + t = ast_parse(f""" + def capture(x): + match x: + case y: + print("default") + capture(5) + """) + + t = br.preinstrument(t) + check_locations(t) + code = compile(t, "foo", "exec") + assert [(2,4)] == get_branches(code) + + @pytest.mark.parametrize("star", ['', '*'] if PYTHON_VERSION >= (3,11) else ['']) def test_try_except(star): t = ast_parse(f""" @@ -781,3 +830,4 @@ def foo(x): check_locations(t) code = compile(t, "foo", "exec") assert [(4,5), (4,10), (7,8), (7,13), (10,11), (10,13)] == get_branches(code) +