Skip to content

Commit

Permalink
Fix pattern matching with case None branches.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586506370
  • Loading branch information
martindemello authored and rchen152 committed Nov 30, 2023
1 parent f11e941 commit bedd6f9
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 5 deletions.
36 changes: 33 additions & 3 deletions pytype/pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ def cover_from_cmp(self, line, case_var):
# current case so that instantiate_case_var can retrieve it.
for d in case_var.data:
self.case_types[line].add(d.cls)
if isinstance(d, abstract.ConcreteValue) and d.pyval is None:
# Need to special-case `case None` since it's compiled differently.
self.uncovered.discard(d.cls)

def cover_from_none(self, line):
cls = self.ctx.convert.none_type
self.case_types[line].add(cls)
self.uncovered.discard(cls)

@property
def complete(self):
Expand Down Expand Up @@ -380,10 +388,26 @@ def _add_literal_branch(
# This has already been covered, and will never succeed.
return False

def add_cmp_branch(self, op: opcodes.OpcodeWithArg, match_var: cfg.Variable,
case_var: cfg.Variable) -> _MatchSuccessType:
def add_none_branch(self, op: opcodes.Opcode, match_var: cfg.Variable):
if op.line in self.matches.match_cases:
if tracker := self.get_current_type_tracker(op, match_var):
tracker.cover_from_none(op.line)
if tracker.uncovered:
return None
else:
# This is the last remaining case, and will always succeed.
tracker.implicit_default = self.ctx.convert.none_type
return True

def add_cmp_branch(
self,
op: opcodes.OpcodeWithArg,
cmp_type: int,
match_var: cfg.Variable,
case_var: cfg.Variable
) -> _MatchSuccessType:
"""Add a compare-based match case branch to the tracker."""
if op.arg != slots.CMP_EQ:
if cmp_type not in (slots.CMP_EQ, slots.CMP_IS):
return None

try:
Expand All @@ -400,6 +424,12 @@ def add_cmp_branch(self, op: opcodes.OpcodeWithArg, match_var: cfg.Variable,
if op.line in self.matches.match_cases:
if tracker := self.get_current_type_tracker(op, match_var):
tracker.cover_from_cmp(op.line, case_var)
if tracker.uncovered:
return None
else:
# This is the last remaining case, and will always succeed.
tracker.implicit_default = case_val
return True

if all(isinstance(x, abstract.ConcreteValue) for x in match_var.data):
# We are matching a union of concrete values, i.e. a Literal
Expand Down
16 changes: 16 additions & 0 deletions pytype/tests/test_pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,22 @@ def f(x: A | B | C | D):
assert_type(x, D)
""")

def test_type_narrowing_none(self):
self.Check("""
class A: pass
class B: pass
def f(x: A | B | None):
match x:
case A() | B():
assert_type(x, A | B)
case None:
assert_type(x, None)
case _:
# This branch will not be entered
assert_type(1, str)
""")

def test_type_narrowing_mixed(self):
self.Check("""
class A: pass
Expand Down
15 changes: 13 additions & 2 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,14 @@ def _is_match_case_op(self, op):
is_match = opname.startswith("MATCH_")
# Opcodes generated by various matches against constant literals.
is_cmp_match = opname in ("COMPARE_OP", "IS_OP", "CONTAINS_OP")
is_none_match = opname in ("POP_JUMP_FORWARD_IF_NOT_NONE",)
# `case _` generates a NOP if the match is not captured. If it is, we need
# to look for the opcode before the STORE_FAST, since the match itself does
# not generate any specific opcode, just stack manipulations.
is_default_match = opname == "NOP" or (
isinstance(op.next, opcodes.STORE_FAST) and
op.line in self._branch_tracker.matches.defaults)
return is_match or is_cmp_match or is_default_match
return is_match or is_cmp_match or is_default_match or is_none_match

def _handle_match_case(self, state, op):
"""Track type narrowing and default cases in a match statement."""
Expand Down Expand Up @@ -1829,7 +1830,7 @@ def _compare_op(self, state, op_arg, op):
"""Pops and compares the top two stack values and pushes a boolean."""
state, (x, y) = state.popn(2)
self._branch_tracker.register_match_type(op)
match_enum = self._branch_tracker.add_cmp_branch(op, x, y)
match_enum = self._branch_tracker.add_cmp_branch(op, op_arg, x, y)
if match_enum is not None:
# The match always succeeds/fails.
ret = self.ctx.convert.bool_values[match_enum].to_variable(state.node)
Expand Down Expand Up @@ -3302,6 +3303,16 @@ def byte_SEND(self, state, op):
return state.push(generator).push(yield_var)

def byte_POP_JUMP_FORWARD_IF_NOT_NONE(self, state, op):
# Check if this is a `case None` statement (3.11+ compiles it directly to a
# conditional jump rather than a compare and then jump).
self._branch_tracker.register_match_type(op)
match_none = self._branch_tracker.add_none_branch(op, state.top())
if match_none is True: # pylint: disable=g-bool-id-comparison
# This always fails due to earlier pattern matches, so replace the top of
# the stack with a None to ensure we do not jump.
state = state.pop_and_discard()
value = self.ctx.convert.none.to_variable(state.node)
state = state.push(value)
return vm_utils.jump_if(state, op, self.ctx,
jump_if_val=frame_state.NOT_NONE,
pop=vm_utils.PopBehavior.ALWAYS)
Expand Down

0 comments on commit bedd6f9

Please sign in to comment.