From 3609469f5a9601e41ea878efc2227e7378f61237 Mon Sep 17 00:00:00 2001 From: Colin Atkinson Date: Fri, 17 Dec 2021 01:51:51 -0500 Subject: [PATCH 1/2] Treat NewTypes like normal subclasses NewTypes are assumed not to inherit any members from their base classes. This results in incorrect inference results. Avoid this by changing the transformation for NewTypes to treat them like any other subclass. https://github.com/PyCQA/pylint/issues/3162 https://github.com/PyCQA/pylint/issues/2296 --- ChangeLog | 5 + astroid/brain/brain_typing.py | 93 ++++++++++++++--- tests/unittest_brain.py | 186 +++++++++++++++++++++++++++++++++- 3 files changed, 269 insertions(+), 15 deletions(-) diff --git a/ChangeLog b/ChangeLog index c3b3e6b90b..432baef35f 100644 --- a/ChangeLog +++ b/ChangeLog @@ -55,6 +55,11 @@ Release date: TBA * Fix test for Python ``3.11``. In some instances ``err.__traceback__`` will be uninferable now. +* Treat ``typing.NewType()`` values as normal subclasses. + + Closes PyCQA/pylint#2296 + Closes PyCQA/pylint#3162 + What's New in astroid 2.11.6? ============================= Release date: TBA diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index 807ba96e6e..bda4198f21 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -10,7 +10,7 @@ from collections.abc import Iterator from functools import partial -from astroid import context, extract_node, inference_tip +from astroid import context, extract_node, inference_tip, nodes from astroid.builder import _extract_single_node from astroid.const import PY38_PLUS, PY39_PLUS from astroid.exceptions import ( @@ -35,8 +35,6 @@ from astroid.util import Uninferable TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"} -TYPING_TYPEVARS = {"TypeVar", "NewType"} -TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"} TYPING_TYPE_TEMPLATE = """ class Meta(type): def __getitem__(self, item): @@ -49,6 +47,13 @@ def __args__(self): class {0}(metaclass=Meta): pass """ +# PEP484 suggests NewType is equivalent to this for typing purposes +# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function +TYPING_NEWTYPE_TEMPLATE = """ +class {derived}({base}): + def __init__(self, val: {base}) -> None: + ... +""" TYPING_MEMBERS = set(getattr(typing, "__all__", [])) TYPING_ALIAS = frozenset( @@ -103,24 +108,33 @@ def __class_getitem__(cls, item): """ -def looks_like_typing_typevar_or_newtype(node): +def looks_like_typing_typevar(node: nodes.Call) -> bool: + func = node.func + if isinstance(func, Attribute): + return func.attrname == "TypeVar" + if isinstance(func, Name): + return func.name == "TypeVar" + return False + + +def looks_like_typing_newtype(node: nodes.Call) -> bool: func = node.func if isinstance(func, Attribute): - return func.attrname in TYPING_TYPEVARS + return func.attrname == "NewType" if isinstance(func, Name): - return func.name in TYPING_TYPEVARS + return func.name == "NewType" return False -def infer_typing_typevar_or_newtype(node, context_itton=None): - """Infer a typing.TypeVar(...) or typing.NewType(...) call""" +def infer_typing_typevar( + node: nodes.Call, ctx: context.InferenceContext | None = None +) -> Iterator[nodes.ClassDef]: + """Infer a typing.TypeVar(...) call""" try: - func = next(node.func.infer(context=context_itton)) + next(node.func.infer(context=ctx)) except (InferenceError, StopIteration) as exc: raise UseInferenceDefault from exc - if func.qname() not in TYPING_TYPEVARS_QUALIFIED: - raise UseInferenceDefault if not node.args: raise UseInferenceDefault # Cannot infer from a dynamic class name (f-string) @@ -129,7 +143,53 @@ def infer_typing_typevar_or_newtype(node, context_itton=None): typename = node.args[0].as_string().strip("'") node = extract_node(TYPING_TYPE_TEMPLATE.format(typename)) - return node.infer(context=context_itton) + return node.infer(context=ctx) + + +def infer_typing_newtype( + node: nodes.Call, ctx: context.InferenceContext | None = None +) -> Iterator[nodes.ClassDef]: + """Infer a typing.NewType(...) call""" + try: + next(node.func.infer(context=ctx)) + except (InferenceError, StopIteration) as exc: + raise UseInferenceDefault from exc + + if len(node.args) != 2: + raise UseInferenceDefault + + # Cannot infer from a dynamic class name (f-string) + if isinstance(node.args[0], JoinedStr) or isinstance(node.args[1], JoinedStr): + raise UseInferenceDefault + + derived, base = node.args + derived_name = derived.as_string().strip("'") + base_name = base.as_string().strip("'") + + new_node: ClassDef = extract_node( + TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name) + ) + new_node.parent = node.parent + + # Base type arg is a normal reference, so no need to do special lookups + if not isinstance(base, nodes.Const): + new_node.postinit( + bases=[base], body=new_node.body, decorators=new_node.decorators + ) + + # If the base type is given as a string (e.g. for a forward reference), + # make a naive attempt to find the corresponding node. + # Note that this will not work with imported types. + if isinstance(base, nodes.Const) and isinstance(base.value, str): + _, resolved_base = node.frame().lookup(base_name) + if resolved_base: + new_node.postinit( + bases=[resolved_base[0]], + body=new_node.body, + decorators=new_node.decorators, + ) + + return new_node.infer(context=ctx) def _looks_like_typing_subscript(node): @@ -403,8 +463,13 @@ def infer_typing_cast( AstroidManager().register_transform( Call, - inference_tip(infer_typing_typevar_or_newtype), - looks_like_typing_typevar_or_newtype, + inference_tip(infer_typing_typevar), + looks_like_typing_typevar, +) +AstroidManager().register_transform( + Call, + inference_tip(infer_typing_newtype), + looks_like_typing_newtype, ) AstroidManager().register_transform( Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index d86c273f32..0a6b83b45b 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1640,6 +1640,26 @@ def test_typing_types(self) -> None: inferred = next(node.infer()) self.assertIsInstance(inferred, nodes.ClassDef, node.as_string()) + def test_typing_typevar_bad_args(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import TypeVar + + T = TypeVar() + T #@ + + U = TypeVar(f"U") + U #@ + """ + ) + assert isinstance(ast_nodes, list) + + no_args_node = ast_nodes[0] + assert list(no_args_node.infer()) == [util.Uninferable] + + fstr_node = ast_nodes[1] + assert list(fstr_node.infer()) == [util.Uninferable] + def test_typing_type_without_tip(self): """Regression test for https://github.com/PyCQA/pylint/issues/5770""" node = builder.extract_node( @@ -1651,7 +1671,171 @@ def make_new_type(t): """ ) with self.assertRaises(UseInferenceDefault): - astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value) + astroid.brain.brain_typing.infer_typing_newtype(node.value) + + def test_typing_newtype_attrs(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + import decimal + from decimal import Decimal + + NewType("Foo", str) #@ + NewType("Bar", "int") #@ + NewType("Baz", Decimal) #@ + NewType("Qux", decimal.Decimal) #@ + """ + ) + assert isinstance(ast_nodes, list) + + # Base type given by reference + foo_node = ast_nodes[0] + + # Should be unambiguous + foo_inferred_all = list(foo_node.infer()) + assert len(foo_inferred_all) == 1 + + foo_inferred = foo_inferred_all[0] + assert isinstance(foo_inferred, astroid.ClassDef) + + # Check base type method is inferred by accessing one of its methods + foo_base_class_method = foo_inferred.getattr("endswith")[0] + assert isinstance(foo_base_class_method, astroid.FunctionDef) + assert foo_base_class_method.qname() == "builtins.str.endswith" + + # Base type given by string (i.e. "int") + bar_node = ast_nodes[1] + bar_inferred_all = list(bar_node.infer()) + assert len(bar_inferred_all) == 1 + bar_inferred = bar_inferred_all[0] + assert isinstance(bar_inferred, astroid.ClassDef) + + bar_base_class_method = bar_inferred.getattr("bit_length")[0] + assert isinstance(bar_base_class_method, astroid.FunctionDef) + assert bar_base_class_method.qname() == "builtins.int.bit_length" + + # Decimal may be reexported from an implementation-defined module. For + # example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's + # _pydecimal. So the expected qname needs to be grabbed dynamically. + decimal_quant_node = builder.extract_node( + """ + from decimal import Decimal + Decimal.quantize #@ + """ + ) + assert isinstance(decimal_quant_node, nodes.NodeNG) + + # Just grab the first result, since infer() may return values for both + # _decimal and _pydecimal + decimal_quant_qname = next(decimal_quant_node.infer()).qname() + + # Base type is from an "import from" + baz_node = ast_nodes[2] + baz_inferred_all = list(baz_node.infer()) + assert len(baz_inferred_all) == 1 + baz_inferred = baz_inferred_all[0] + assert isinstance(baz_inferred, astroid.ClassDef) + + baz_base_class_method = baz_inferred.getattr("quantize")[0] + assert isinstance(baz_base_class_method, astroid.FunctionDef) + assert decimal_quant_qname == baz_base_class_method.qname() + + # Base type is from an import + qux_node = ast_nodes[3] + qux_inferred_all = list(qux_node.infer()) + qux_inferred = qux_inferred_all[0] + assert isinstance(qux_inferred, astroid.ClassDef) + + qux_base_class_method = qux_inferred.getattr("quantize")[0] + assert isinstance(qux_base_class_method, astroid.FunctionDef) + assert decimal_quant_qname == qux_base_class_method.qname() + + def test_typing_newtype_bad_args(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + NoArgs = NewType() + NoArgs #@ + + OneArg = NewType("OneArg") + OneArg #@ + + ThreeArgs = NewType("ThreeArgs", int, str) + ThreeArgs #@ + + DynamicArg = NewType(f"DynamicArg", int) + DynamicArg #@ + + DynamicBase = NewType("DynamicBase", f"int") + DynamicBase #@ + """ + ) + assert isinstance(ast_nodes, list) + + node: nodes.NodeNG + for node in ast_nodes: + assert list(node.infer()) == [util.Uninferable] + + def test_typing_newtype_user_defined(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + class A: + def __init__(self, value: int): + self.value = value + + a = A(5) + a #@ + + B = NewType("B", A) + b = B(5) + b #@ + """ + ) + assert isinstance(ast_nodes, list) + + for node in ast_nodes: + self._verify_node_has_expected_attr(node) + + def test_typing_newtype_forward_reference(self) -> None: + # Similar to the test above, but using a forward reference for "A" + ast_nodes = builder.extract_node( + """ + from typing import NewType + + B = NewType("B", "A") + + class A: + def __init__(self, value: int): + self.value = value + + a = A(5) + a #@ + + b = B(5) + b #@ + """ + ) + assert isinstance(ast_nodes, list) + + for node in ast_nodes: + self._verify_node_has_expected_attr(node) + + def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None: + inferred_all = list(node.infer()) + assert len(inferred_all) == 1 + inferred = inferred_all[0] + assert isinstance(inferred, astroid.Instance) + + # Should be able to infer that the "value" attr is present on both types + val = inferred.getattr("value")[0] + assert isinstance(val, astroid.AssignAttr) + + # Sanity check: nonexistent attr is not inferred + with self.assertRaises(AttributeInferenceError): + inferred.getattr("bad_attr") def test_namedtuple_nested_class(self): result = builder.extract_node( From 6bbb84071b8f57e492f0c527462912b95aae5179 Mon Sep 17 00:00:00 2001 From: Colin Atkinson Date: Mon, 7 Mar 2022 06:25:49 -0500 Subject: [PATCH 2/2] Improve NewType inference for string forward refs --- astroid/brain/brain_typing.py | 111 ++++++++++++++++++++--- tests/unittest_brain.py | 166 ++++++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+), 14 deletions(-) diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index bda4198f21..1a3178df24 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -14,6 +14,7 @@ from astroid.builder import _extract_single_node from astroid.const import PY38_PLUS, PY39_PLUS from astroid.exceptions import ( + AstroidImportError, AttributeInferenceError, InferenceError, UseInferenceDefault, @@ -171,27 +172,109 @@ def infer_typing_newtype( ) new_node.parent = node.parent - # Base type arg is a normal reference, so no need to do special lookups - if not isinstance(base, nodes.Const): - new_node.postinit( - bases=[base], body=new_node.body, decorators=new_node.decorators - ) + new_bases: list[NodeNG] = [] - # If the base type is given as a string (e.g. for a forward reference), - # make a naive attempt to find the corresponding node. - # Note that this will not work with imported types. - if isinstance(base, nodes.Const) and isinstance(base.value, str): + if not isinstance(base, nodes.Const): + # Base type arg is a normal reference, so no need to do special lookups + new_bases = [base] + elif isinstance(base, nodes.Const) and isinstance(base.value, str): + # If the base type is given as a string (e.g. for a forward reference), + # make a naive attempt to find the corresponding node. _, resolved_base = node.frame().lookup(base_name) if resolved_base: - new_node.postinit( - bases=[resolved_base[0]], - body=new_node.body, - decorators=new_node.decorators, - ) + base_node = resolved_base[0] + + # If the value is from an "import from" statement, follow the import chain + if isinstance(base_node, nodes.ImportFrom): + ctx = ctx.clone() if ctx else context.InferenceContext() + ctx.lookupname = base_name + base_node = next(base_node.infer(context=ctx)) + + new_bases = [base_node] + elif "." in base.value: + possible_base = _try_find_imported_object_from_str(node, base.value, ctx) + if possible_base: + new_bases = [possible_base] + + if new_bases: + new_node.postinit( + bases=new_bases, body=new_node.body, decorators=new_node.decorators + ) return new_node.infer(context=ctx) +def _try_find_imported_object_from_str( + node: nodes.Call, + name: str, + ctx: context.InferenceContext | None, +) -> nodes.NodeNG | None: + for statement_mod_name, _ in _possible_module_object_splits(name): + # Find import statements that may pull in the appropriate modules + # The name used to find this statement may not correspond to the name of the module actually being imported + # For example, "import email.charset" is found by lookup("email") + _, resolved_bases = node.frame().lookup(statement_mod_name) + if not resolved_bases: + continue + + resolved_base = resolved_bases[0] + if isinstance(resolved_base, nodes.Import): + # Extract the names of the module as they are accessed from actual code + scope_names = {(alias or name) for (name, alias) in resolved_base.names} + aliases = {alias: name for (name, alias) in resolved_base.names if alias} + + # Find potential mod_name, obj_name splits that work with the available names + # for the module in this scope + import_targets = [ + (mod_name, obj_name) + for (mod_name, obj_name) in _possible_module_object_splits(name) + if mod_name in scope_names + ] + if not import_targets: + continue + + import_target, name_in_mod = import_targets[0] + import_target = aliases.get(import_target, import_target) + + # Try to import the module and find the object in it + try: + resolved_mod: nodes.Module = resolved_base.do_import_module( + import_target + ) + except AstroidImportError: + # If the module doesn't actually exist, try the next option + continue + + # Try to find the appropriate ClassDef or other such node in the target module + _, object_results_in_mod = resolved_mod.lookup(name_in_mod) + if not object_results_in_mod: + continue + + base_node = object_results_in_mod[0] + + # If the value is from an "import from" statement, follow the import chain + if isinstance(base_node, nodes.ImportFrom): + ctx = ctx.clone() if ctx else context.InferenceContext() + ctx.lookupname = name_in_mod + base_node = next(base_node.infer(context=ctx)) + + return base_node + + return None + + +def _possible_module_object_splits( + dot_str: str, +) -> Iterator[tuple[str, str]]: + components = dot_str.split(".") + popped = [] + + while components: + popped.append(components.pop()) + + yield ".".join(components), ".".join(reversed(popped)) + + def _looks_like_typing_subscript(node): """Try to figure out if a Subscript node *might* be a typing-related subscript""" if isinstance(node, Name): diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 0a6b83b45b..5af34ead44 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1837,6 +1837,172 @@ def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None: with self.assertRaises(AttributeInferenceError): inferred.getattr("bad_attr") + def test_typing_newtype_forward_reference_imported(self) -> None: + all_ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "decimal.Decimal") + B = NewType("B", "decimal_mod_alias.Decimal") + C = NewType("C", "Decimal") + D = NewType("D", "DecimalAlias") + + import decimal + import decimal as decimal_mod_alias + from decimal import Decimal + from decimal import Decimal as DecimalAlias + + Decimal #@ + + a = A(decimal.Decimal(2)) + a #@ + b = B(decimal_mod_alias.Decimal(2)) + b #@ + c = C(Decimal(2)) + c #@ + d = D(DecimalAlias(2)) + d #@ + """ + ) + assert isinstance(all_ast_nodes, list) + + real_dec, *ast_nodes = all_ast_nodes + + real_quantize = next(real_dec.infer()).getattr("quantize") + + for node in ast_nodes: + all_inferred = list(node.infer()) + assert len(all_inferred) == 1 + inferred = all_inferred[0] + assert isinstance(inferred, astroid.Instance) + + assert inferred.getattr("quantize") == real_quantize + + def test_typing_newtype_forward_ref_bad_base(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "DoesntExist") + + a = A() + a #@ + + # Valid name, but not actually imported + B = NewType("B", "decimal.Decimal") + + b = B() + b #@ + + # AST works out, but can't import the module + import not_a_real_module + + C = NewType("C", "not_a_real_module.SomeClass") + c = C() + c #@ + + # Real module, fake base class name + import email.charset + + D = NewType("D", "email.charset.BadClassRef") + d = D() + d #@ + + # Real module, but aliased differently than used + import email.header as header_mod + + E = NewType("E", "email.header.Header") + e = E(header_mod.Header()) + e #@ + """ + ) + assert isinstance(ast_nodes, list) + + for ast_node in ast_nodes: + inferred = next(ast_node.infer()) + + with self.assertRaises(astroid.AttributeInferenceError): + inferred.getattr("value") + + def test_typing_newtype_forward_ref_nested_module(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "email.charset.Charset") + B = NewType("B", "charset.Charset") + + # header is unused in both cases, but verifies that module name is properly checked + import email.header, email.charset + from email import header, charset + + real = charset.Charset() + real #@ + + a = A(email.charset.Charset()) + a #@ + + b = B(charset.Charset()) + """ + ) + assert isinstance(ast_nodes, list) + + real, *newtypes = ast_nodes + + real_inferred_all = list(real.infer()) + assert len(real_inferred_all) == 1 + real_inferred = real_inferred_all[0] + + real_method = real_inferred.getattr("get_body_encoding") + + for newtype_node in newtypes: + newtype_inferred_all = list(newtype_node.infer()) + assert len(newtype_inferred_all) == 1 + newtype_inferred = newtype_inferred_all[0] + + newtype_method = newtype_inferred.getattr("get_body_encoding") + + assert real_method == newtype_method + + def test_typing_newtype_forward_ref_nested_class(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "SomeClass.Nested") + + class SomeClass: + class Nested: + def method(self) -> None: + pass + + real = SomeClass.Nested() + real #@ + + a = A(SomeClass.Nested()) + a #@ + """ + ) + assert isinstance(ast_nodes, list) + + real, newtype = ast_nodes + + real_all_inferred = list(real.infer()) + assert len(real_all_inferred) == 1 + real_inferred = real_all_inferred[0] + real_method = real_inferred.getattr("method") + + newtype_all_inferred = list(newtype.infer()) + assert len(newtype_all_inferred) == 1 + newtype_inferred = newtype_all_inferred[0] + + # This could theoretically work, but for now just here to check that + # the "forward-declared module" inference doesn't totally break things + with self.assertRaises(astroid.AttributeInferenceError): + newtype_method = newtype_inferred.getattr("method") + + assert real_method == newtype_method + def test_namedtuple_nested_class(self): result = builder.extract_node( """