From 98c3cc8a909beb12da31f8b5369c5e1d3504340f Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Tue, 22 Oct 2024 10:34:27 -0700 Subject: [PATCH] Annotate type on abstract/_function_base.py, and also correct some types on _base.py and _classes.py that were incorrect. Also remove the actual return statement from `SignedFunction._check_paramspec_args` as there was only a single place being called, and the return was not being used. PiperOrigin-RevId: 688604582 --- pytype/abstract/_base.py | 5 +- pytype/abstract/_classes.py | 2 +- pytype/abstract/_function_base.py | 268 +++++++++++++++++++-------- pytype/overlays/dataclass_overlay.py | 4 +- 4 files changed, 199 insertions(+), 80 deletions(-) diff --git a/pytype/abstract/_base.py b/pytype/abstract/_base.py index 7ee644dc7..306812e0a 100644 --- a/pytype/abstract/_base.py +++ b/pytype/abstract/_base.py @@ -183,7 +183,10 @@ def property_get( return self def get_special_attribute( - self, unused_node: "cfg.CFGNode", name: str, unused_valself: "cfg.Binding" + self, + unused_node: "cfg.CFGNode", + name: str, + unused_valself: "cfg.Variable", ) -> "cfg.Variable | None": """Fetch a special attribute (e.g., __get__, __iter__).""" if name == "__class__": diff --git a/pytype/abstract/_classes.py b/pytype/abstract/_classes.py index ca8bfc383..85215ccb3 100644 --- a/pytype/abstract/_classes.py +++ b/pytype/abstract/_classes.py @@ -1265,7 +1265,7 @@ def getitem_slot( ) def get_special_attribute( - self, node: cfg.CFGNode, name: str, valself: cfg.Binding + self, node: cfg.CFGNode, name: str, valself: cfg.Variable ) -> cfg.Variable | None: if ( valself diff --git a/pytype/abstract/_function_base.py b/pytype/abstract/_function_base.py index b47dc497e..0034a4702 100644 --- a/pytype/abstract/_function_base.py +++ b/pytype/abstract/_function_base.py @@ -1,9 +1,11 @@ """Base abstract representations of functions.""" +from collections.abc import Callable, Generator, Sequence import contextlib import inspect import itertools import logging +from typing import Any, TYPE_CHECKING from pytype.abstract import _base from pytype.abstract import _classes @@ -16,9 +18,16 @@ from pytype.errors import error_types from pytype.types import types -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access +if TYPE_CHECKING: + from pytype import context # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype import datatypes # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype import matcher # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _interpreter_function # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.typegraph import cfg # pylint: disable=g-bad-import-order,g-import-not-at-top + class Function(_instance_base.SimpleValue, types.Function): """Base class for function objects (NativeFunction, InterpreterFunction). @@ -30,7 +39,7 @@ class Function(_instance_base.SimpleValue, types.Function): bound_class: type["BoundFunction"] - def __init__(self, name, ctx): + def __init__(self, name: str, ctx: "context.Context") -> None: super().__init__(name, ctx) self.cls = _classes.FunctionPyTDClass(self, ctx) self.is_attribute_of_class = False @@ -43,7 +52,9 @@ def __init__(self, name, ctx): self.ctx.root_node, name ) - def property_get(self, callself, is_class=False): + def property_get( + self, callself: "cfg.Variable", is_class: bool = False + ) -> "BoundFunction|Function": if self.name == "__new__" or not callself or is_class: return self self.is_attribute_of_class = True @@ -51,9 +62,9 @@ def property_get(self, callself, is_class=False): # that would be tied into a BoundFunction instance. However, those # Variables aren't necessarily visible from other parts of the CFG binding # this function. See test_duplicate_getproperty() in tests/test_flow.py. - return self.bound_class(callself, self) + return self.bound_class(callself, self) # pytype: disable=wrong-arg-types - def _get_cell_variable_name(self, var): + def _get_cell_variable_name(self, var: "cfg.Variable") -> str | None: """Get the python variable name of a pytype Variable.""" f = self.ctx.vm.frame if not f: @@ -64,7 +75,13 @@ def _get_cell_variable_name(self, var): return name return None - def match_args(self, node, args, alias_map=None, match_all_views=False): + def match_args( + self, + node: "cfg.CFGNode", + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + match_all_views: bool = False, + ) -> "list[matcher.GoodMatch]": """Check whether the given arguments can match the function signature.""" for a in args.posargs: if not a.bindings: @@ -75,13 +92,21 @@ def match_args(self, node, args, alias_map=None, match_all_views=False): raise error_types.UndefinedParameterError(name) return self._match_args_sequentially(node, args, alias_map, match_all_views) - def _match_args_sequentially(self, node, args, alias_map, match_all_views): + def _match_args_sequentially( + self, + node: "cfg.CFGNode", + args: function.Args, + alias_map: "datatypes.UnionFind | None", + match_all_views: bool, + ) -> "list[matcher.GoodMatch]": raise NotImplementedError(self.__class__.__name__) - def __repr__(self): + def __repr__(self) -> str: return self.full_name + "(...)" - def _extract_defaults(self, defaults_var): + def _extract_defaults( + self, defaults_var: "cfg.Variable" + ) -> "tuple[cfg.Variable, ...] | None": """Extracts defaults from a Variable, used by set_function_defaults. Args: @@ -122,7 +147,7 @@ def _extract_defaults(self, defaults_var): def set_function_defaults(self, node, defaults_var): raise NotImplementedError(self.__class__.__name__) - def update_signature_scope(self, cls): + def update_signature_scope(self, cls: _classes.InterpreterClass) -> None: return @@ -135,25 +160,38 @@ class NativeFunction(Function): ctx: context.Context instance. """ - def __init__(self, name, func, ctx): + def __init__(self, name: str, func: Function, ctx: "context.Context") -> None: super().__init__(name, ctx) self.func = func self.bound_class = lambda callself, underlying: self - def argcount(self, _): - return self.func.func_code.argcount + def argcount(self, _: "cfg.CFGNode") -> int: + return self.func.func_code.argcount # pytype: disable=attribute-error - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: Function, + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + ): sig = None - if isinstance(self.func.__self__, _classes.CallableClass): - sig = function.Signature.from_callable(self.func.__self__) + if isinstance( + self.func.__self__, # pytype: disable=attribute-error + _classes.CallableClass, + ): + sig = function.Signature.from_callable( + self.func.__self__ # pytype: disable=attribute-error + ) args = args.simplify(node, self.ctx, match_signature=sig) posargs = [u.AssignToNewVariable(node) for u in args.posargs] namedargs = { k: u.AssignToNewVariable(node) for k, u in args.namedargs.items() } try: - inspect.signature(self.func).bind(node, *posargs, **namedargs) + inspect.signature(self.func).bind( + node, *posargs, **namedargs + ) # pytype: disable=wrong-arg-types except ValueError as e: # Happens for, e.g., # def f((x, y)): pass @@ -201,20 +239,30 @@ def call(self, node, func, args, alias_map=None): ) sig = function.Signature.from_param_names(self.name, argnames) raise error_types.DuplicateKeyword(sig, args, self.ctx, "self") - return self.func(node, *posargs, **namedargs) + return self.func( # pytype: disable=not-callable + node, *posargs, **namedargs + ) def get_positional_names(self): - code = self.func.func_code + # TODO: b/350643999 - this is the only place that func_code is used, + # find out what the type of this is and delete the dead code if not used. + code = self.func.func_code # pytype: disable=attribute-error return list(code.varnames[: code.argcount]) - def property_get(self, callself, is_class=False): + def property_get( + self, callself: "cfg.Variable", is_class: bool = False + ) -> "NativeFunction": return self class BoundFunction(_base.BaseValue): """An function type which has had an argument bound into it.""" - def __init__(self, callself, underlying): + def __init__( + self, + callself: "cfg.Variable", + underlying: "_interpreter_function.InterpreterFunction", + ) -> None: super().__init__(underlying.name, underlying.ctx) self.cls = _classes.FunctionPyTDClass(self, self.ctx) self._callself = callself @@ -237,7 +285,7 @@ def __init__(self, callself, underlying): else: self.alias_map = None - def _get_self_annot(self, callself): + def _get_self_annot(self, callself: "cfg.Variable") -> _base.BaseValue: if isinstance(self.underlying, SignedFunction): self_type = self.underlying.get_self_type_param() else: @@ -251,18 +299,24 @@ def _get_self_annot(self, callself): else: return self_type - def argcount(self, node): + def argcount(self, node: "cfg.CFGNode") -> int: return self.underlying.argcount(node) - 1 # account for self @property - def signature(self): + def signature(self) -> function.Signature: return self.underlying.signature.drop_first_parameter() @property - def callself(self): + def callself(self) -> "cfg.Variable": return self._callself - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding", + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": if self.name.endswith(".__init__"): self.ctx.callself_stack.append(self._callself) # The "self" parameter is automatically added to the list of arguments, but @@ -300,28 +354,30 @@ def call(self, node, func, args, alias_map=None): self.ctx.callself_stack.pop() return node, ret - def get_positional_names(self): + def get_positional_names(self) -> Sequence[str]: return self.underlying.get_positional_names() - def has_varargs(self): + def has_varargs(self) -> bool: return self.underlying.has_varargs() - def has_kwargs(self): + def has_kwargs(self) -> bool: return self.underlying.has_kwargs() @property - def is_abstract(self): + def is_abstract(self) -> bool: return self.underlying.is_abstract @is_abstract.setter - def is_abstract(self, value): + def is_abstract(self, value: bool) -> None: self.underlying.is_abstract = value @property - def is_classmethod(self): + def is_classmethod(self) -> bool: return self.underlying.is_classmethod - def repr_names(self, callself_repr=None): + def repr_names( + self, callself_repr: "Callable[[cfg.Variable], str] | None" = None + ) -> Sequence[str]: """Names to use in the bound function's string representation. This function can return multiple names because there may be multiple @@ -345,10 +401,12 @@ def repr_names(self, callself_repr=None): underlying = underlying.split(".", 1)[-1] return [callself + "." + underlying for callself in callself_names] - def __repr__(self): + def __repr__(self) -> str: return self.repr_names()[0] + "(...)" - def get_special_attribute(self, node, name, valself): + def get_special_attribute( + self, node: "cfg.CFGNode", name: str, valself: "cfg.Variable" + ) -> "cfg.Variable | None": if name == "__self__": return self.callself elif name == "__func__": @@ -360,34 +418,37 @@ class BoundInterpreterFunction(BoundFunction): """The method flavor of InterpreterFunction.""" @contextlib.contextmanager - def record_calls(self): + def record_calls(self) -> Generator[None, None, None]: with self.underlying.record_calls(): yield + # TODO: b/350643999 - figure out the return type def get_first_opcode(self): return self.underlying.code.get_first_opcode(skip_noop=True) @property - def has_overloads(self): + def has_overloads(self) -> bool: return self.underlying.has_overloads @property - def is_overload(self): + def is_overload(self) -> bool: return self.underlying.is_overload @is_overload.setter - def is_overload(self, value): - self.underlying.is_overload = value + def is_overload(self, is_overload: bool) -> None: + self.underlying.is_overload = is_overload @property def defaults(self): return self.underlying.defaults - def iter_signature_functions(self): + def iter_signature_functions( + self, + ) -> "Generator[BoundInterpreterFunction, None, None]": for f in self.underlying.iter_signature_functions(): yield self.underlying.bound_class(self._callself, f) - def reset_overloads(self): + def reset_overloads(self) -> contextlib._GeneratorContextManager: return self.underlying.reset_overloads() @@ -398,7 +459,13 @@ class BoundPyTDFunction(BoundFunction): class ClassMethod(_base.BaseValue): """Implements @classmethod methods in pyi.""" - def __init__(self, name, method, callself, ctx): + def __init__( + self, + name: str, + method: "_interpreter_function.InterpreterFunction", + callself: "cfg.Variable", + ctx: "context.Context", + ) -> None: super().__init__(name, ctx) self.cls = self.ctx.convert.function_type self.method = method @@ -407,25 +474,33 @@ def __init__(self, name, method, callself, ctx): self._callcls = callself self.signatures = self.method.signatures - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding", + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": return self.method.call( node, func, args.replace(posargs=(self._callcls,) + args.posargs) ) - def to_bound_function(self): + def to_bound_function(self) -> BoundPyTDFunction: return BoundPyTDFunction(self._callcls, self.method) class StaticMethod(_base.BaseValue): """Implements @staticmethod methods in pyi.""" - def __init__(self, name, method, _, ctx): + def __init__( + self, name: str, method: Function, _, ctx: "context.Context" + ) -> None: super().__init__(name, ctx) self.cls = self.ctx.convert.function_type self.method = method self.signatures = self.method.signatures - def call(self, *args, **kwargs): + def call(self, *args, **kwargs) -> "tuple[cfg.CFGNode, cfg.Variable]": return self.method.call(*args, **kwargs) @@ -436,14 +511,26 @@ class Property(_base.BaseValue): resolved as a function, not as a constant. """ - def __init__(self, name, method, callself, ctx): + def __init__( + self, + name: str, + method: Function, + callself: "cfg.Variable", + ctx: "context.Context", + ) -> None: super().__init__(name, ctx) self.cls = self.ctx.convert.function_type self.method = method self._callself = callself self.signatures = self.method.signatures - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding | None", + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": func = func or self.to_binding(node) args = args or function.Args(posargs=(self._callself,)) return self.method.call(node, func, args.replace(posargs=(self._callself,))) @@ -455,7 +542,9 @@ class SignedFunction(Function): Subclasses should define call(self, node, f, args) and set self.bound_class. """ - def __init__(self, signature, ctx): + def __init__( + self, signature: function.Signature, ctx: "context.Context" + ) -> None: # We should only instantiate subclasses of SignedFunction assert self.__class__ != SignedFunction super().__init__(signature.name, ctx) @@ -465,11 +554,13 @@ def __init__(self, signature, ctx): self._has_self_annot = False @property - def has_self_annot(self): + def has_self_annot(self) -> bool: return self._has_self_annot @contextlib.contextmanager - def set_self_annot(self, annot_class: _base.BaseValue | None): + def set_self_annot( + self, annot_class: _base.BaseValue | None + ) -> Generator[None, None, None]: """Set the annotation for `self` in a class.""" self_name = self.signature.param_names[0] old_self = self.signature.annotations.get(self_name) @@ -495,21 +586,28 @@ def get_self_type_param(self): return param return None - def argcount(self, _): + def argcount(self, _: "cfg.CFGNode") -> int: return len(self.signature.param_names) - def get_nondefault_params(self): + def get_nondefault_params(self) -> Generator[tuple[str, bool], None, None]: return ( (n, n in self.signature.kwonly_params) for n in self.signature.param_names if n not in self.signature.defaults ) - def match_and_map_args(self, node, args, alias_map): + def match_and_map_args( + self, + node: "cfg.CFGNode", + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + ) -> "tuple[list[matcher.GoodMatch], dict[str, cfg.Variable]]": """Calls match_args() and _map_args().""" return self.match_args(node, args, alias_map), self._map_args(node, args) - def _map_args(self, node, args): + def _map_args( + self, node: "cfg.CFGNode", args: function.Args + ) -> "dict[str, cfg.Variable]": """Map call args to function args. This emulates how Python would map arguments of function calls. It takes @@ -585,7 +683,7 @@ def _map_args(self, node, args): callargs[kwargs_name] = k.to_variable(node) return callargs - def _check_paramspec_args(self, args): + def _check_paramspec_args(self, args: function.Args) -> None: args_pspec, kwargs_pspec = None, None for name, _, formal in self.signature.iter_args(args): if not _isinstance(formal, "ParameterizedClass"): @@ -606,14 +704,20 @@ def _check_paramspec_args(self, args): and args_pspec.paramspec == kwargs_pspec.paramspec ) if valid: - return args_pspec.paramspec + return else: self.ctx.errorlog.paramspec_error( self.ctx.vm.frames, "ParamSpec.args and ParamSpec.kwargs must be used together", ) - def _match_args_sequentially(self, node, args, alias_map, match_all_views): + def _match_args_sequentially( + self, + node: "cfg.CFGNode", + args: function.Args, + alias_map: "datatypes.UnionFind | None", + match_all_views: bool, + ) -> "list[matcher.GoodMatch]": args_to_match = [] self._check_paramspec_args(args) for name, arg, formal in self.signature.iter_args(args): @@ -635,10 +739,12 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): ) return [m.subst for m in matches] - def get_first_opcode(self): + def get_first_opcode(self) -> None: return None - def set_function_defaults(self, node, defaults_var): + def set_function_defaults( + self, node: "cfg.CFGNode", defaults_var: "cfg.Variable" + ) -> None: """Attempts to set default arguments of a function. If defaults_var is not an unambiguous tuple (i.e. one that can be processed @@ -659,7 +765,9 @@ def set_function_defaults(self, node, defaults_var): defaults = dict(zip(self.signature.param_names[-len(defaults) :], defaults)) self.signature.defaults = defaults - def _mutations_generator(self, node, first_arg, substs): + def _mutations_generator( + self, node: "cfg.CFGNode", first_arg: "cfg.Variable", substs + ) -> Callable[[], Generator[function.Mutation, None, None]]: def generator(): """Yields mutations.""" if ( @@ -697,7 +805,7 @@ def generator(): # extra time. return generator - def update_signature_scope(self, cls): + def update_signature_scope(self, cls: _classes.InterpreterClass) -> None: self.signature.excluded_types.update([t.name for t in cls.template]) self.signature.add_scope(cls) @@ -709,23 +817,25 @@ class SimpleFunction(SignedFunction): record calls or try to infer types. """ - def __init__(self, signature, ctx): + def __init__( + self, signature: function.Signature, ctx: "context.Context" + ) -> None: super().__init__(signature, ctx) self.bound_class = BoundFunction @classmethod def build( cls, - name, - param_names, - posonly_count, - varargs_name, - kwonly_params, - kwargs_name, - defaults, - annotations, - ctx, - ): + name: str, + param_names: tuple[str, ...], + posonly_count: int, + varargs_name: str | None, + kwonly_params: tuple[str, ...], + kwargs_name: str | None, + defaults: "dict[str, cfg.Variable]", + annotations: dict[str, Any], + ctx: "context.Context", + ) -> "SimpleFunction": """Returns a SimpleFunction. Args: @@ -761,7 +871,7 @@ def build( ) return cls(signature, ctx) - def _skip_parameter_matching(self): + def _skip_parameter_matching(self) -> bool: """Check whether we should skip parameter matching. This is use to skip parameter matching for function calls in the context of @@ -783,7 +893,13 @@ def _skip_parameter_matching(self): return self.signature.has_return_annotation or self.full_name == "__init__" - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding|None", + args: function.Args, + alias_map: "datatypes.UnionFind | None" = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": args = args.simplify(node, self.ctx) callargs = self._map_args(node, args) substs = [] diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index 49c0df868..590ce189b 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -295,10 +295,10 @@ def _match_args_sequentially(self, node, args, alias_map, match_all_views): default = self.ctx.new_unsolvable(node) replace = abstract.SimpleFunction.build( name=self.name, - param_names=["obj"], + param_names=("obj",), posonly_count=1, varargs_name=None, - kwonly_params=[f.name for f in fields], + kwonly_params=tuple(f.name for f in fields), kwargs_name=None, defaults={f.name: default for f in fields}, annotations={f.name: f.typ for f in fields},