Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Google sync #1538

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pytype/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,13 +1448,13 @@ def match_posargs_count(self, stack, cls, posargs, match_args, details=None):

@_error_name("incomplete-match")
def incomplete_match(self, stack, line, cases, details=None):
cases = ", ".join(cases)
msg = f"The enum match is missing the following cases: {cases}"
cases = ", ".join(str(x) for x in cases)
msg = f"The match is missing the following cases: {cases}"
self.error(stack, msg, details=details, lineno=line)

@_error_name("redundant-match")
def redundant_match(self, stack, case, details=None):
msg = f"This enum case has already been covered: {case}."
msg = f"This case has already been covered: {case}."
self.error(stack, msg, details=details)

@_error_name("paramspec-error")
Expand Down
3 changes: 2 additions & 1 deletion pytype/load_pytd.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def resolve_external_types(self, mod_ast, module_map, aliases, *, mod_name):
mod_ast = mod_ast.Visit(visitors.LookupExternalTypes(
module_map, self_name=name, module_alias_map=all_aliases))
except KeyError as e:
raise BadDependencyError(str(e), name) from e
key = "".join(str(arg) for arg in e.args)
raise BadDependencyError(key, name) from e
return mod_ast

def resolve_module_alias(self, name, *, lookup_ast=None,
Expand Down
2 changes: 1 addition & 1 deletion pytype/load_pytd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def f(x: Ellipsis): ...

def test_import_typevar(self):
# Regression test for the loader crashing with a
# ""Duplicate top level items: 'T', 'T'" error.
# "Duplicate top level items: 'T', 'T'" error.
self._import(a="""
from typing import TypeVar
T = TypeVar('T')
Expand Down
120 changes: 115 additions & 5 deletions pytype/pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,29 @@ def invalidate(self):
self.is_valid = False


class _LiteralTracker:
"""Track literal cases for exhaustiveness."""

def __init__(self, match_var):
self.match_var = match_var
self.members = [x.pyval for x in match_var.data]
self.uncovered = set(self.members)
# The last case in an exhaustive match always succeeds.
self.implicit_default = None
# Invalidate the tracker if we run into code that is not a simple match
# against a single value.
self.is_valid = True

def cover(self, literal_case):
self.uncovered.discard(literal_case.pyval)

def cover_all(self):
self.uncovered = set()

def invalidate(self):
self.is_valid = False


class _TypeTracker:
"""Track class type cases for exhaustiveness."""

Expand Down Expand Up @@ -149,6 +172,7 @@ class BranchTracker:
def __init__(self, ast_matches, ctx):
self.matches = _Matches(ast_matches)
self._enum_tracker = {}
self._literal_tracker = {}
self._type_tracker: Dict[int, Dict[int, _TypeTracker]] = (
collections.defaultdict(dict))
self._match_types: Dict[int, Set[_MatchTypes]] = (
Expand Down Expand Up @@ -192,6 +216,38 @@ def _get_enum_tracker(
return None
return enum_tracker

def _get_literal_tracker(
self, match_var: cfg.Variable, match_line: Optional[int]
) -> Optional[_LiteralTracker]:
"""Get the literal tracker for a match line."""
if match_line is None:
return None
if match_line not in self._literal_tracker:
self._add_new_literal_match(match_var, match_line)
literal_tracker = self._literal_tracker[match_line]
if (match_var.id != literal_tracker.match_var.id or
self._match_types[match_line] != {_MatchTypes.CMP}):
# We are matching a tuple or structure with different literals in it.
literal_tracker.invalidate()
return None
return literal_tracker

def _is_literal_match(
self, match_val: abstract.BaseValue, case_val: abstract.BaseValue
) -> bool:
if not (isinstance(match_val, abstract.Instance) and
isinstance(match_val.cls, abstract.Class) and
match_val.cls.is_enum):
return False
if not (isinstance(case_val, abstract.Instance) and
case_val.cls == match_val.cls):
return False
return True

def _add_new_literal_match(self, match_var: cfg.Variable, match_line: int):
self._literal_tracker[match_line] = _LiteralTracker(match_var)
self._active_ends.add(self.matches.start_to_end[match_line])

def _add_new_type_match(self, match_var: cfg.Variable, match_line: int):
self._type_tracker[match_line][match_var.id] = _TypeTracker(
match_var, self.ctx)
Expand Down Expand Up @@ -286,9 +342,50 @@ def _add_enum_branch(
# This has already been covered, and will never succeed.
return False

def add_cmp_branch(self, op: opcodes.Opcode, match_var: cfg.Variable,
def _add_literal_branch(
self,
op: opcodes.Opcode,
match_var: cfg.Variable,
case_val: abstract.SimpleValue
) -> Optional[bool]:
"""Add a case branch for a literal match to the tracker."""
if op in self._seen_opcodes:
match_line = self.matches.match_cases.get(op.line)
tracker = self._get_literal_tracker(match_var, match_line)
if not tracker:
return None
if (tracker.implicit_default and case_val and
case_val.cls == tracker.implicit_default.cls):
return True
else:
return None
else:
self._seen_opcodes.add(op)
match_line = self.matches.match_cases.get(op.line)
tracker = self._get_literal_tracker(match_var, match_line)
if not tracker or not tracker.is_valid:
return None
if not isinstance(case_val, abstract.ConcreteValue):
tracker.invalidate()
return None
if case_val.pyval in tracker.uncovered:
tracker.cover(case_val)
if tracker.uncovered:
return None
else:
# This is the last remaining case, and will always succeed.
tracker.implicit_default = case_val
return True
else:
# This has already been covered, and will never succeed.
return False

def add_cmp_branch(self, op: opcodes.OpcodeWithArg, match_var: cfg.Variable,
case_var: cfg.Variable) -> _MatchSuccessType:
"""Add a compare-based match case branch to the tracker."""
if op.arg != slots.CMP_EQ:
return None

try:
case_val = abstract_utils.get_atomic_value(case_var)
except abstract_utils.ConversionError:
Expand All @@ -300,14 +397,19 @@ def add_cmp_branch(self, op: opcodes.Opcode, match_var: cfg.Variable,
# because even an ambigious cmp match will require the type to be set within
# the case branch).
op = cast(opcodes.OpcodeWithArg, op)
if (op.arg == slots.CMP_EQ and op.line in self.matches.match_cases):
if op.line in self.matches.match_cases:
if tracker := self.get_current_type_tracker(op, match_var):
tracker.cover_from_cmp(op.line, case_var)

if all(isinstance(x, abstract.ConcreteValue) for x in match_var.data):
# We are matching a union of concrete values, i.e. a Literal
return self._add_literal_branch(op, match_var, case_val)

try:
match_val = abstract_utils.get_atomic_value(match_var)
except abstract_utils.ConversionError:
return None

if self._is_enum_match(match_val, case_val):
return self._add_enum_branch(op, match_val, case_val)
else:
Expand All @@ -325,9 +427,12 @@ def add_default_branch(self, op: opcodes.Opcode) -> _MatchSuccessType:
match_line = self.matches.match_cases.get(op.line)
if match_line is None:
return None
if match_line not in self._enum_tracker:
if match_line in self._enum_tracker:
self._enum_tracker[match_line].cover_all()
elif match_line in self._literal_tracker:
self._literal_tracker[match_line].cover_all()
else:
return None
self._enum_tracker[match_line].cover_all()
return True

def check_ending(self,
Expand All @@ -349,7 +454,12 @@ def check_ending(self,
ret = []
for i in done:
for start in self.matches.end_to_starts[i]:
tracker = self._enum_tracker[start]
if start in self._enum_tracker:
tracker = self._enum_tracker[start]
elif start in self._literal_tracker:
tracker = self._literal_tracker[start]
else:
assert False
if tracker.is_valid:
if uncovered := tracker.uncovered:
ret.append((start, uncovered))
Expand Down
Loading
Loading