From 8c4da4abf4e6aa6ce7a6a5646bc7f545d1d23dc2 Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 23 Mar 2021 13:34:22 -0700 Subject: [PATCH] track `if typing.TYPE_CHECKING` to warn about non runtime bindings When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context. --- pyflakes/checker.py | 115 ++++++++++++++++++------- pyflakes/test/test_type_annotations.py | 66 ++++++++++++++ 2 files changed, 151 insertions(+), 30 deletions(-) diff --git a/pyflakes/checker.py b/pyflakes/checker.py index 754ab30c..a5a13c0e 100644 --- a/pyflakes/checker.py +++ b/pyflakes/checker.py @@ -226,10 +226,11 @@ class Binding: the node that this binding was last used. """ - def __init__(self, name, source): + def __init__(self, name, source, runtime=True): self.name = name self.source = source self.used = False + self.runtime = runtime def __str__(self): return self.name @@ -260,8 +261,8 @@ def redefines(self, other): class Builtin(Definition): """A definition created for all Python builtins.""" - def __init__(self, name): - super().__init__(name, None) + def __init__(self, name, runtime=True): + super().__init__(name, None, runtime=runtime) def __repr__(self): return '<{} object {!r} at 0x{:x}>'.format( @@ -305,10 +306,10 @@ class Importation(Definition): @type fullName: C{str} """ - def __init__(self, name, source, full_name=None): + def __init__(self, name, source, full_name=None, runtime=True): self.fullName = full_name or name self.redefined = [] - super().__init__(name, source) + super().__init__(name, source, runtime=runtime) def redefines(self, other): if isinstance(other, SubmoduleImportation): @@ -353,11 +354,11 @@ class SubmoduleImportation(Importation): name is also the same, to avoid false positives. """ - def __init__(self, name, source): + def __init__(self, name, source, runtime=True): # A dot should only appear in the name when it is a submodule import assert '.' in name and (not source or isinstance(source, ast.Import)) package_name = name.split('.')[0] - super().__init__(package_name, source) + super().__init__(package_name, source, runtime=runtime) self.fullName = name def redefines(self, other): @@ -375,7 +376,8 @@ def source_statement(self): class ImportationFrom(Importation): - def __init__(self, name, source, module, real_name=None): + def __init__( + self, name, source, module, real_name=None, runtime=True): self.module = module self.real_name = real_name or name @@ -384,7 +386,7 @@ def __init__(self, name, source, module, real_name=None): else: full_name = module + '.' + self.real_name - super().__init__(name, source, full_name) + super().__init__(name, source, full_name, runtime=runtime) def __str__(self): """Return import full name with alias.""" @@ -404,8 +406,8 @@ def source_statement(self): class StarImportation(Importation): """A binding created by a 'from x import *' statement.""" - def __init__(self, name, source): - super().__init__('*', source) + def __init__(self, name, source, runtime=True): + super().__init__('*', source, runtime=runtime) # Each star importation needs a unique name, and # may not be the module name otherwise it will be deemed imported self.name = name + '.*' @@ -494,7 +496,7 @@ class ExportBinding(Binding): C{__all__} will not have an unused import warning reported for them. """ - def __init__(self, name, source, scope): + def __init__(self, name, source, scope, runtime=True): if '__all__' in scope and isinstance(source, ast.AugAssign): self.names = list(scope['__all__'].names) else: @@ -525,7 +527,7 @@ def _add_to_names(container): # If not list concatenation else: break - super().__init__(name, source) + super().__init__(name, source, runtime=runtime) class Scope(dict): @@ -732,6 +734,7 @@ class Checker: nodeDepth = 0 offset = None _in_annotation = AnnotationState.NONE + _in_type_check_guard = False builtIns = set(builtin_vars).union(_MAGIC_GLOBALS) _customBuiltIns = os.environ.get('PYFLAKES_BUILTINS') @@ -1000,9 +1003,11 @@ def addBinding(self, node, value): # then assume the rebound name is used as a global or within a loop value.used = self.scope[value.name].used - # don't treat annotations as assignments if there is an existing value - # in scope - if value.name not in self.scope or not isinstance(value, Annotation): + # always allow the first assignment or if not already a runtime value, + # but do not shadow an existing assignment with an annotation or non + # runtime value. + if (not existing or not existing.runtime or ( + not isinstance(value, Annotation) and value.runtime)): cur_scope_pos = -1 # As per PEP 572, use scope in which outermost generator is defined while ( @@ -1073,12 +1078,18 @@ def handleNodeLoad(self, node, parent): self.report(messages.InvalidPrintSyntax, node) try: - scope[name].used = (self.scope, node) + n = scope[name] + if (not n.runtime and not ( + self._in_type_check_guard + or self._in_annotation)): + self.report(messages.UndefinedName, node, name) + return + + n.used = (self.scope, node) # if the name of SubImportation is same as # alias of other Importation and the alias # is used, SubImportation also should be marked as used. - n = scope[name] if isinstance(n, Importation) and n._has_alias(): try: scope[n.fullName].used = (self.scope, node) @@ -1143,12 +1154,13 @@ def handleNodeStore(self, node): break parent_stmt = self.getParent(node) + runtime = not self._in_type_check_guard if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None: binding = Annotation(name, node) elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or ( parent_stmt != node._pyflakes_parent and not self.isLiteralTupleUnpacking(parent_stmt)): - binding = Binding(name, node) + binding = Binding(name, node, runtime=runtime) elif ( name == '__all__' and isinstance(self.scope, ModuleScope) and @@ -1157,11 +1169,12 @@ def handleNodeStore(self, node): (ast.Assign, ast.AugAssign, ast.AnnAssign) ) ): - binding = ExportBinding(name, node._pyflakes_parent, self.scope) + binding = ExportBinding( + name, node._pyflakes_parent, self.scope, runtime=runtime) elif isinstance(parent_stmt, ast.NamedExpr): - binding = NamedExprAssignment(name, node) + binding = NamedExprAssignment(name, node, runtime=runtime) else: - binding = Assignment(name, node) + binding = Assignment(name, node, runtime=runtime) self.addBinding(node, binding) def handleNodeDelete(self, node): @@ -1805,7 +1818,39 @@ def DICT(self, node): def IF(self, node): if isinstance(node.test, ast.Tuple) and node.test.elts != []: self.report(messages.IfTuple, node) - self.handleChildren(node) + + self.handleNode(node.test, node) + + # check if the body/orelse should be handled specially because it is + # a if TYPE_CHECKING guard. + test = node.test + reverse = False + if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not): + test = test.operand + reverse = True + + type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack) + orig = self._in_type_check_guard + + # normalize body and orelse to a list + body, orelse = ( + i if isinstance(i, list) else [i] + for i in (node.body, node.orelse)) + + # set the guard and handle the body + if type_checking and not reverse: + self._in_type_check_guard = True + + for n in body: + self.handleNode(n, node) + + # set the guard and handle the orelse + if type_checking: + self._in_type_check_guard = True if reverse else orig + + for n in orelse: + self.handleNode(n, node) + self._in_type_check_guard = orig IFEXP = IF @@ -1920,7 +1965,10 @@ def FUNCTIONDEF(self, node): with self._type_param_scope(node): self.LAMBDA(node) - self.addBinding(node, FunctionDefinition(node.name, node)) + self.addBinding( + node, + FunctionDefinition( + node.name, node, runtime=not self._in_type_check_guard)) # doctest does not process doctest within a doctest, # or in nested functions. if (self.withDoctest and @@ -2005,7 +2053,10 @@ def CLASSDEF(self, node): for stmt in node.body: self.handleNode(stmt, node) - self.addBinding(node, ClassDefinition(node.name, node)) + self.addBinding( + node, + ClassDefinition( + node.name, node, runtime=not self._in_type_check_guard)) def AUGASSIGN(self, node): self.handleNodeLoad(node.target, node) @@ -2038,12 +2089,15 @@ def TUPLE(self, node): LIST = TUPLE def IMPORT(self, node): + runtime = not self._in_type_check_guard for alias in node.names: if '.' in alias.name and not alias.asname: - importation = SubmoduleImportation(alias.name, node) + importation = SubmoduleImportation( + alias.name, node, runtime=runtime) else: name = alias.asname or alias.name - importation = Importation(name, node, alias.name) + importation = Importation( + name, node, alias.name, runtime=runtime) self.addBinding(node, importation) def IMPORTFROM(self, node): @@ -2055,6 +2109,7 @@ def IMPORTFROM(self, node): module = ('.' * node.level) + (node.module or '') + runtime = not self._in_type_check_guard for alias in node.names: name = alias.asname or alias.name if node.module == '__future__': @@ -2072,10 +2127,10 @@ def IMPORTFROM(self, node): self.scope.importStarred = True self.report(messages.ImportStarUsed, node, module) - importation = StarImportation(module, node) + importation = StarImportation(module, node, runtime=runtime) else: - importation = ImportationFrom(name, node, - module, alias.name) + importation = ImportationFrom( + name, node, module, alias.name, runtime=runtime) self.addBinding(node, importation) def TRY(self, node): diff --git a/pyflakes/test/test_type_annotations.py b/pyflakes/test/test_type_annotations.py index 4c8b998f..cb339d8c 100644 --- a/pyflakes/test/test_type_annotations.py +++ b/pyflakes/test/test_type_annotations.py @@ -645,6 +645,55 @@ def f() -> T: pass """) + def test_typing_guard_import(self): + # T is imported for runtime use + self.flakes(""" + from typing import TYPE_CHECKING + + if TYPE_CHECKING: + from t import T + + def f(x) -> T: + from t import T + + assert isinstance(x, T) + return x + """) + # T is defined at runtime in one side of the if/else block + self.flakes(""" + from typing import TYPE_CHECKING, Union + + if TYPE_CHECKING: + from t import T + else: + T = object + + if not TYPE_CHECKING: + U = object + else: + from t import U + + def f(x) -> Union[T, U]: + assert isinstance(x, (T, U)) + return x + """) + + def test_typing_guard_import_runtime_error(self): + # T and U are not bound for runtime use + self.flakes(""" + from typing import TYPE_CHECKING, Union + + if TYPE_CHECKING: + from t import T + + class U: + pass + + def f(x) -> Union[T, U]: + assert isinstance(x, (T, U)) + return x + """, m.UndefinedName, m.UndefinedName) + def test_typing_guard_for_protocol(self): self.flakes(""" from typing import TYPE_CHECKING @@ -659,6 +708,23 @@ def f() -> int: pass """) + def test_typing_guard_with_elif_branch(self): + # This test will not raise an error even though Protocol is not + # defined outside TYPE_CHECKING because Pyflakes does not do case + # analysis. + self.flakes(""" + from typing import TYPE_CHECKING + if TYPE_CHECKING: + from typing import Protocol + elif False: + Protocol = object + else: + pass + class C(Protocol): + def f(): # type: () -> int + pass + """) + def test_typednames_correct_forward_ref(self): self.flakes(""" from typing import TypedDict, List, NamedTuple