From 5a7f76c0340e6dc5d9c47421c088df27e5ddf789 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Mon, 28 Oct 2024 06:27:18 -0700 Subject: [PATCH] Add type annotations to abstract/function.py. Also fix all other type annotations that were wrong, or add suppressions to where pytype is wrong or cannot figure out the types due to limitations. PiperOrigin-RevId: 690586637 --- pytype/abstract/_base.py | 2 +- pytype/abstract/_function_base.py | 2 +- pytype/abstract/_interpreter_function.py | 12 +- pytype/abstract/abstract_utils.py | 2 +- pytype/abstract/function.py | 283 ++++++++++++++++------- pytype/typegraph/cfg_utils.py | 2 +- pytype/vm_utils.py | 2 +- 7 files changed, 206 insertions(+), 99 deletions(-) diff --git a/pytype/abstract/_base.py b/pytype/abstract/_base.py index ac2b918d6..3de542fe8 100644 --- a/pytype/abstract/_base.py +++ b/pytype/abstract/_base.py @@ -142,7 +142,7 @@ def get_instance_type_parameter( return self.ctx.new_unsolvable(node) def get_formal_type_parameter( - self, t # TODO: b/350643999 - Figure out the type of 't'. + self, t: "BaseValue" ) -> "BaseValue": """Get the class's type for the type parameter. diff --git a/pytype/abstract/_function_base.py b/pytype/abstract/_function_base.py index 86529f4c1..8b9b763d2 100644 --- a/pytype/abstract/_function_base.py +++ b/pytype/abstract/_function_base.py @@ -699,7 +699,7 @@ def _check_paramspec_args(self, args: function.Args) -> None: for name, _, formal in self.signature.iter_args(args): if not _isinstance(formal, "ParameterizedClass"): continue - params = formal.get_formal_type_parameters() + params = formal.get_formal_type_parameters() # pytype: disable=attribute-error if name == self.signature.varargs_name: for param in params.values(): if _isinstance(param, "ParamSpecArgs"): diff --git a/pytype/abstract/_interpreter_function.py b/pytype/abstract/_interpreter_function.py index 655632915..1fa37769b 100644 --- a/pytype/abstract/_interpreter_function.py +++ b/pytype/abstract/_interpreter_function.py @@ -212,7 +212,7 @@ def __init__( defaults, kw_defaults, closure, - annotations: "dict[str, _base.BaseValue]", + annotations: dict[str, pytd.Type], overloads, ctx: "context.Context", ) -> None: @@ -328,7 +328,7 @@ def _check_signature(self) -> None: ) def _build_signature( - self, name: str, annotations: "dict[str, _base.BaseValue]" + self, name: str, annotations: dict[str, pytd.Type] ) -> function.Signature: """Build a function.Signature object representing this function.""" vararg_name = None @@ -564,7 +564,7 @@ def _paramspec_signature( self, callable_type: _classes.ParameterizedClass, substs: "list[matcher.GoodMatch]", - ) -> str | None: + ) -> function.Signature | None: # Unpack the paramspec substitution we have created in the matcher. rhs = callable_type.formal_type_parameters[0] if _isinstance(rhs, "Concatenate"): @@ -593,15 +593,15 @@ def _handle_paramspec( if not sig.has_return_annotation: return retval = sig.annotations["return"] - if not (_isinstance(retval, "CallableClass") and retval.has_paramspec()): + if not (_isinstance(retval, "CallableClass") and retval.has_paramspec()): # pytype: disable=attribute-error return ret_sig = self._paramspec_signature(retval, substs) if ret_sig: ret_annot = self.ctx.pytd_convert.signature_to_callable(ret_sig) annotations["return"] = ret_annot for name, _, annot in sig.iter_args(callargs): - if _isinstance(annot, "CallableClass") and annot.has_paramspec(): - param_sig = self._paramspec_signature(annot, substs) + if _isinstance(annot, "CallableClass") and annot.has_paramspec(): # pytype: disable=attribute-error + param_sig = self._paramspec_signature(annot, substs) # pytype: disable=wrong-arg-types if param_sig: param_annot = self.ctx.pytd_convert.signature_to_callable(param_sig) annotations[name] = param_annot diff --git a/pytype/abstract/abstract_utils.py b/pytype/abstract/abstract_utils.py index 4945d31c1..8ac7254bc 100644 --- a/pytype/abstract/abstract_utils.py +++ b/pytype/abstract/abstract_utils.py @@ -1020,7 +1020,7 @@ def get_generic_type( def with_empty_substitutions( subst: datatypes.AliasingDict[str, cfg.Variable], - pytd_type: pytd.Signature, + pytd_type: "_base.BaseValue", node: cfg.CFGNode, ctx: "context.Context", ) -> datatypes.AliasingDict[str, cfg.Variable]: diff --git a/pytype/abstract/function.py b/pytype/abstract/function.py index 3c9263c97..762981673 100644 --- a/pytype/abstract/function.py +++ b/pytype/abstract/function.py @@ -2,10 +2,11 @@ import abc import collections +from collections.abc import Generator, Iterable, Sequence import dataclasses import itertools import logging -from typing import Any, TypeVar, cast +from typing import Any, TYPE_CHECKING, TypeVar, cast import attrs from pytype import datatypes @@ -19,29 +20,49 @@ from pytype.typegraph import cfg_utils from pytype.types import types -log = logging.getLogger(__name__) +if TYPE_CHECKING: + from pytype import context # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype import state # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _classes # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _function_base # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _singletons # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _typing # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _instance_base # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _interpreter_function # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.pyc import opcodes # pylint: disable=g-bad-import-order,g-import-not-at-top + + +log: logging.Logger = logging.getLogger(__name__) _isinstance = abstract_utils._isinstance # pylint: disable=protected-access _make = abstract_utils._make # pylint: disable=protected-access -def argname(i): +def argname(i: int) -> str: """Get a name for an unnamed positional argument, given its position.""" return "_" + str(i) -def get_signatures(func): +def get_signatures(func: "_function_base.Function") -> "list[Signature]": """Gets the given function's signatures.""" if _isinstance(func, "PyTDFunction"): - return [sig.signature for sig in func.signatures] + # TODO: b/350643999 - There's something going wrong here. This is + # implemented as a non-property method and it's not being called but yet + # this code somehow doesn't get complained in tests. This will either crash + # or not run upon reaching this part. + return [sig.signature for sig in func.signatures] # pytype: disable=attribute-error elif _isinstance(func, "InterpreterFunction"): - return [f.signature for f in func.signature_functions()] + f: "_interpreter_function.InterpreterFunction" = func # pytype: disable=annotation-type-mismatch + return [f.signature for f in f.signature_functions()] elif _isinstance(func, "BoundFunction"): - sigs = get_signatures(func.underlying) + f: "_function_base.BoundFunction" = func # pytype: disable=annotation-type-mismatch + sigs = get_signatures(f.underlying) return [sig.drop_first_parameter() for sig in sigs] # drop "self" elif _isinstance(func, ("ClassMethod", "StaticMethod")): - return get_signatures(func.method) + f: "_function_base.ClassMethod | _function_base.StaticMethod" = func # pytype: disable=annotation-type-mismatch + return get_signatures(f.method) elif _isinstance(func, "SignedFunction"): - return [func.signature] + f: "_function_base.SignedFunction" = func # pytype: disable=annotation-type-mismatch + return [f.signature] elif _isinstance(func, "AMBIGUOUS_OR_EMPTY"): return [Signature.from_any()] elif func.__class__.__name__ == "PropertyInstance": @@ -53,7 +74,7 @@ def get_signatures(func): # def f()... return [] elif _isinstance(func.cls, "CallableClass"): - return [Signature.from_callable(func.cls)] + return [Signature.from_callable(func.cls)] # pytype: disable=wrong-arg-types else: unwrapped = abstract_utils.maybe_unwrap_decorated_function(func) if unwrapped: @@ -74,7 +95,7 @@ def get_signatures(func): raise NotImplementedError(func.__class__.__name__) -def _print(t): +def _print(t: _base.BaseValue) -> str: return pytd_utils.Print(t.to_pytd_type_of_instance()) @@ -138,26 +159,28 @@ def __init__( ) @property - def has_return_annotation(self): + def has_return_annotation(self) -> bool: return "return" in self.annotations @property - def has_param_annotations(self): + def has_param_annotations(self) -> bool: return bool(self.annotations.keys() - {"return"}) - def has_default(self, name): + def has_default(self, name: str) -> bool: return name in self.defaults - def add_scope(self, cls): + def add_scope(self, cls: "_classes.InterpreterClass") -> None: """Add scope for type parameters in annotations.""" - annotations = {} + annotations: "dict[str, _base.BaseValue]" = {} for key, val in self.annotations.items(): annotations[key] = val.ctx.annotation_utils.add_scope( val, self.excluded_types, cls ) self.annotations = annotations - def _postprocess_annotation(self, name, annotation): + def _postprocess_annotation( + self, name: str, annotation: _base.BaseValue + ) -> "_classes.ParameterizedClass | _base.BaseValue": """Postprocess the given annotation.""" ctx = annotation.ctx if name == self.varargs_name: @@ -176,13 +199,18 @@ def _postprocess_annotation(self, name, annotation): else: return annotation - def set_annotation(self, name, annotation): + def set_annotation(self, name: str, annotation: _base.BaseValue) -> None: self.annotations[name] = self._postprocess_annotation(name, annotation) - def del_annotation(self, name): + def del_annotation(self, name: str) -> None: del self.annotations[name] # Raises KeyError if annotation does not exist. - def check_type_parameters(self, stack, opcode, is_attribute_of_class): + def check_type_parameters( + self, + stack: "tuple[state.SimpleFrame]", + opcode: "opcodes.Opcode", + is_attribute_of_class: bool, + ) -> None: """Check type parameters in function.""" if not self.annotations: return @@ -247,12 +275,12 @@ def check_type_parameters(self, stack, opcode, is_attribute_of_class): ) ctx.errorlog.invalid_annotation(stack, annot, msg) - def drop_first_parameter(self): + def drop_first_parameter(self) -> "Signature": return self._replace(param_names=self.param_names[1:]) def _make_concatenated_type( self, type1: _base.BaseValue, type2: _base.BaseValue | None - ) -> _base.BaseValue | None: + ) -> "_typing.Concatenate | None": """Concatenates type1 and type2 if possible. If type2 is a ParamSpec or Concatenate object, creates a new Concatenate @@ -291,20 +319,22 @@ def prepend_parameter(self: _SigT, name: str, typ: _base.BaseValue) -> _SigT: annots = {**self.annotations, name: typ} return self._replace(param_names=param_names, annotations=annots) - def mandatory_param_count(self): + def mandatory_param_count(self) -> int: num = len([name for name in self.param_names if name not in self.defaults]) num += len( [name for name in self.kwonly_params if name not in self.defaults] ) return num - def maximum_param_count(self): + def maximum_param_count(self) -> int | None: if self.varargs_name or self.kwargs_name: return None return len(self.param_names) + len(self.kwonly_params) @classmethod - def from_pytd(cls, ctx, name, sig): + def from_pytd( + cls, ctx: "context.Context", name: str, sig: "pytd.Signature" + ) -> "Signature": """Construct an abstract signature from a pytd signature.""" pytd_annotations = [ (p.name, p.type) @@ -345,7 +375,7 @@ def param_to_var(p): ) @classmethod - def from_callable(cls, val): + def from_callable(cls, val: "_classes.CallableClass") -> "Signature": annotations = { argname(i): val.formal_type_parameters[i] for i in range(val.num_args) } @@ -363,7 +393,12 @@ def from_callable(cls, val): ) @classmethod - def from_param_names(cls, name, param_names, kind=pytd.ParameterKind.REGULAR): + def from_param_names( + cls, + name: str, + param_names: Sequence[str], + kind=pytd.ParameterKind.REGULAR, + ) -> "Signature": """Construct a minimal signature from a name and a list of param names.""" names = tuple(param_names) if kind == pytd.ParameterKind.REGULAR: @@ -391,7 +426,7 @@ def from_param_names(cls, name, param_names, kind=pytd.ParameterKind.REGULAR): ) @classmethod - def from_any(cls): + def from_any(cls) -> "Signature": """Treat `Any` as `f(...) -> Any`.""" return cls( name="", @@ -404,14 +439,14 @@ def from_any(cls): annotations={}, ) - def has_param(self, name): + def has_param(self, name: str): return ( name in self.param_names or name in self.kwonly_params or (name == self.varargs_name or name == self.kwargs_name) ) - def insert_varargs_and_kwargs(self, args): + def insert_varargs_and_kwargs(self, args: Iterable[str]): """Insert varargs and kwargs from args into the signature. Args: @@ -436,11 +471,11 @@ def insert_varargs_and_kwargs(self, args): ) return self._replace(param_names=new_param_names) - _ATTRIBUTES = set( + _ATTRIBUTES: set[str] = set( __init__.__code__.co_varnames[: __init__.__code__.co_argcount] ) - {"self", "postprocess_annotations"} - def _replace(self, **kwargs): + def _replace(self: _SigT, **kwargs) -> _SigT: """Returns a copy of the signature with the specified values replaced.""" assert not set(kwargs) - self._ATTRIBUTES for attr in self._ATTRIBUTES: @@ -449,7 +484,11 @@ def _replace(self, **kwargs): kwargs["postprocess_annotations"] = False return type(self)(**kwargs) - def iter_args(self, args): + def iter_args( + self, args: "Args" + ) -> Generator[ + tuple[str, cfg.Variable | None, _base.BaseValue | None], None, None + ]: """Iterates through the given args, attaching names and expected types.""" for i, posarg in enumerate(args.posargs): if i < len(self.param_names): @@ -485,7 +524,7 @@ def iter_args(self, args): self.annotations.get(self.kwargs_name), ) - def check_defaults(self, ctx): + def check_defaults(self, ctx: "context.Context") -> None: """Raises an error if a non-default param follows a default.""" has_default = False for name in self.param_names: @@ -499,7 +538,7 @@ def check_defaults(self, ctx): ctx.errorlog.invalid_function_definition(ctx.vm.stack(), msg) return - def _yield_arguments(self): + def _yield_arguments(self) -> Generator[str, None, None]: """Yield all the function arguments.""" names = list(self.param_names) if self.varargs_name: @@ -517,10 +556,10 @@ def _yield_arguments(self): " = " + default if default else "" ) - def _print_annot(self, name): + def _print_annot(self, name: str) -> str | None: return _print(self.annotations[name]) if name in self.annotations else None - def _print_default(self, name): + def _print_default(self, name: str) -> str | None: if name in self.defaults: values = self.defaults[name].data if len(values) > 1: @@ -530,7 +569,7 @@ def _print_default(self, name): else: return None - def __repr__(self): + def __repr__(self) -> str: args = list(self._yield_arguments()) if self.posonly_count: args = args[: self.posonly_count] + ["/"] + args[self.posonly_count :] @@ -538,14 +577,18 @@ def __repr__(self): ret = self._print_annot("return") return f"def {self.name}({args}) -> {ret if ret else 'Any'}" - def get_self_arg(self, callargs): + def get_self_arg( + self, callargs: dict[str, cfg.Variable] + ) -> cfg.Variable | None: """Returns the 'self' or 'cls' arg, if any.""" if self.param_names and self.param_names[0] in ("self", "cls"): return callargs.get(self.param_names[0]) else: return None - def get_first_arg(self, callargs): + def get_first_arg( + self, callargs: dict[str, cfg.Variable] + ) -> cfg.Variable | None: """Returns the first non-self/cls arg, if any.""" if not self.param_names: return None @@ -557,7 +600,12 @@ def get_first_arg(self, callargs): return None return callargs.get(name) - def populate_annotation_dict(self, annots, ctx, param_names=None): + def populate_annotation_dict( + self, + annots: dict[str, pytd.Type | _singletons.Unsolvable], + ctx: "context.Context", + param_names: tuple[str, ...] | None = None, + ) -> None: """Populate annotation dict with default values.""" if param_names is None: param_names = self.param_names @@ -570,7 +618,9 @@ def populate_annotation_dict(self, annots, ctx, param_names=None): annots[self.kwargs_name] = ctx.convert.dict_type -def _convert_namedargs(namedargs): +def _convert_namedargs( + namedargs: dict[str, cfg.Variable], +) -> dict[str, cfg.Variable]: return {} if namedargs is None else namedargs @@ -593,24 +643,27 @@ class Args: starargs: cfg.Variable | None = None starstarargs: cfg.Variable | None = None - def has_namedargs(self): + def has_namedargs(self) -> bool: return bool(self.namedargs) - def has_non_namedargs(self): + def has_non_namedargs(self) -> bool: return bool(self.posargs or self.starargs or self.starstarargs) - def is_empty(self): + def is_empty(self) -> bool: return not (self.has_namedargs() or self.has_non_namedargs()) - def starargs_as_tuple(self, node, ctx): + def starargs_as_tuple( + self, node: cfg.CFGNode, ctx: "context.Context" + ) -> tuple[Any, ...]: try: - args = self.starargs and abstract_utils.get_atomic_python_constant( - self.starargs, tuple + args: Any | None = ( + self.starargs + and abstract_utils.get_atomic_python_constant(self.starargs, tuple) ) except abstract_utils.ConversionError: args = None if not args: - return args + return args # pytype: disable=bad-return-type return tuple( var if var.bindings else ctx.convert.empty.to_variable(node) for var in args @@ -627,7 +680,13 @@ def starstarargs_as_dict(self): return None return kwdict.pyval - def _expand_typed_star(self, node, star, count, ctx): + def _expand_typed_star( + self, + node: cfg.CFGNode, + star: cfg.Variable, + count: int, + ctx: "context.Context", + ) -> list[cfg.Variable]: """Convert *xs: Sequence[T] -> [T, T, ...].""" if not count: return [] @@ -636,7 +695,13 @@ def _expand_typed_star(self, node, star, count, ctx): p = ctx.new_unsolvable(node) return [p.AssignToNewVariable(node) for _ in range(count)] - def _unpack_and_match_args(self, node, ctx, match_signature, starargs_tuple): + def _unpack_and_match_args( + self, + node: cfg.CFGNode, + ctx: "context.Context", + match_signature: Signature, + starargs_tuple: tuple[cfg.Variable, ...], + ) -> tuple[tuple[cfg.Variable, ...], cfg.Variable | None]: """Match args against a signature with unpacking.""" posargs = self.posargs namedargs = self.namedargs @@ -711,7 +776,12 @@ def _unpack_and_match_args(self, node, ctx, match_signature, starargs_tuple): # We have **kwargs but no *args in the invocation return posargs + tuple(pre), None - def simplify(self, node, ctx, match_signature=None): + def simplify( + self, + node: cfg.CFGNode, + ctx: "context.Context", + match_signature: Signature | None = None, + ) -> "Args": """Try to insert part of *args, **kwargs into posargs / namedargs.""" # TODO(rechen): When we have type information about *args/**kwargs, # we need to check it before doing this simplification. @@ -735,6 +805,7 @@ def simplify(self, node, ctx, match_signature=None): # **kwargs, starstarargs will have is_concrete set to False, so # preserve it as an abstract dict. If not, we just had named args packed # into starstarargs, so set starstarargs to None. + assert starstarargs is not None kwdict = starstarargs.data[0] if _isinstance(kwdict, "Dict") and not kwdict.is_concrete: cls = kwdict.cls @@ -780,7 +851,7 @@ def simplify(self, node, ctx, match_signature=None): simplify(starstarargs), ) - def get_variables(self): + def get_variables(self) -> list[cfg.Variable]: variables = list(self.posargs) + list(self.namedargs.values()) if self.starargs is not None: variables.append(self.starargs) @@ -788,23 +859,23 @@ def get_variables(self): variables.append(self.starstarargs) return variables - def replace_posarg(self, pos, val): + def replace_posarg(self, pos: int, val: cfg.Variable) -> "Args": new_posargs = self.posargs[:pos] + (val,) + self.posargs[pos + 1 :] return self.replace(posargs=new_posargs) - def replace_namedarg(self, name, val): + def replace_namedarg(self, name: str, val: cfg.Variable) -> "Args": new_namedargs = dict(self.namedargs) new_namedargs[name] = val return self.replace(namedargs=new_namedargs) - def delete_namedarg(self, name): + def delete_namedarg(self, name: str) -> "Args": new_namedargs = {k: v for k, v in self.namedargs.items() if k != name} return self.replace(namedargs=new_namedargs) - def replace(self, **kwargs): + def replace(self, **kwargs) -> "Args": return attrs.evolve(self, **kwargs) - def has_opaque_starargs_or_starstarargs(self): + def has_opaque_starargs_or_starstarargs(self) -> bool: return any( arg and not _isinstance(arg, "PythonConstant") for arg in (self.starargs, self.starstarargs) @@ -814,17 +885,23 @@ def has_opaque_starargs_or_starstarargs(self): class ParamSpecMatch(_base.BaseValue): """Match a paramspec against a sig.""" - def __init__(self, paramspec, sig, ctx): + def __init__( + self, paramspec: _base.BaseValue, sig: Signature, ctx: "context.Context" + ) -> None: super().__init__("ParamSpecMatch", ctx) self.paramspec = paramspec self.sig = sig - def instantiate(self, node, container=None): + def instantiate( + self, + node: cfg.CFGNode, + container: "_instance_base.SimpleValue | abstract_utils.DummyContainer | None" = None, + ) -> cfg.Variable: return self.to_variable(node) - def prefix(self): + def prefix(self) -> "tuple[_typing.ParamSpec]": if _isinstance(self.paramspec, "Concatenate"): - return self.paramspec.args + return self.paramspec.args # pytype: disable=attribute-error else: return () @@ -837,14 +914,14 @@ class Mutation: name: str value: cfg.Variable - def __eq__(self, other): + def __eq__(self, other: "Mutation") -> bool: return ( self.instance == other.instance and self.name == other.name and frozenset(self.value.data) == frozenset(other.value.data) ) - def __hash__(self): + def __hash__(self) -> int: return hash((self.instance, self.name, frozenset(self.value.data))) @@ -868,41 +945,51 @@ def get_parameter(self, node, param_name): class AbstractReturnType(_ReturnType): """An abstract return type.""" - def __init__(self, t, ctx): + def __init__(self, t: _base.BaseValue, ctx: "context.Context") -> None: self._type = t self._ctx = ctx @property - def name(self): + def name(self) -> str: return self._type.full_name - def instantiate_parameter(self, node, param_name): + def instantiate_parameter( + self, node: cfg.CFGNode, param_name: str + ) -> tuple[cfg.CFGNode, cfg.Variable]: param = self._type.get_formal_type_parameter(param_name) return self._ctx.vm.init_class(node, param) - def get_parameter(self, node, param_name): + def get_parameter(self, node: cfg.CFGNode, param_name: str): return self._type.get_formal_type_parameter(param_name) class PyTDReturnType(_ReturnType): """A PyTD return type.""" - def __init__(self, t, subst, sources, ctx): + def __init__( + self, + t: _base.BaseValue, + subst: datatypes.AliasingDict[str, cfg.Variable], + sources: list[cfg.Binding], + ctx: "context.Context", + ) -> None: self._type = t self._subst = subst self._sources = sources self._ctx = ctx @property - def name(self): + def name(self) -> str: return self._type.name - def instantiate_parameter(self, node, param_name): + def instantiate_parameter( + self, node: cfg.CFGNode, param_name: str + ) -> cfg.Variable: _, instance_var = self.instantiate(node) instance = abstract_utils.get_atomic_value(instance_var) return instance.get_instance_type_parameter(param_name) - def instantiate(self, node): + def instantiate(self, node: cfg.CFGNode) -> tuple[cfg.CFGNode, cfg.Variable]: """Instantiate the pytd return type.""" # Type parameter values, which are instantiated by the matcher, will end up # in the return value. Since the matcher does not call __init__, we need to @@ -934,12 +1021,14 @@ def instantiate(self, node): ret.AddBinding(self._ctx.convert.empty, [], node) return node, ret - def get_parameter(self, node, param_name): + def get_parameter(self, node: cfg.CFGNode, param_name: str): t = self._ctx.convert.constant_to_value(self._type, self._subst, node) return t.get_formal_type_parameter(param_name) -def _splats_to_any(seq, ctx): +def _splats_to_any( + seq: Sequence[cfg.Variable], ctx: "context.Context" +) -> tuple[cfg.Variable, ...]: return tuple( ctx.new_unsolvable(ctx.root_node) if abstract_utils.is_var_splat(v) else v for v in seq @@ -947,14 +1036,14 @@ def _splats_to_any(seq, ctx): def call_function( - ctx, - node, - func_var, - args, - fallback_to_unsolvable=True, - allow_never=False, - strict_filter=True, -): + ctx: "context.Context", + node: cfg.CFGNode, + func_var: cfg.Variable, + args: Args, + fallback_to_unsolvable: bool = True, + allow_never: bool = False, + strict_filter: bool = True, +) -> tuple[cfg.CFGNode, cfg.Variable]: """Call a function. Args: @@ -1050,7 +1139,12 @@ def call_function( raise error # pylint: disable=raising-bad-type -def match_all_args(ctx, node, func, args): +def match_all_args( + ctx: "context.Context", + node: cfg.CFGNode, + func: "_function_base.NativeFunction|_interpreter_function.InterpreterFunction", + args: "Args", +) -> "tuple[Args, Sequence[tuple[Exception, str, _base.BaseValue]]]": """Call match_args multiple times to find all type errors. Args: @@ -1111,7 +1205,9 @@ def match_all_args(ctx, node, func, args): return args, errors -def has_visible_namedarg(node, args, names): +def has_visible_namedarg( + node: cfg.CFGNode, args: Args, names: set[str] +) -> bool: # Note: this method should be called judiciously, as HasCombination is # potentially very expensive. namedargs = {args.namedargs[name] for name in names} @@ -1123,7 +1219,13 @@ def has_visible_namedarg(node, args, names): return False -def handle_typeguard(node, ret: _ReturnType, first_arg, ctx, func_name=None): +def handle_typeguard( + node: cfg.CFGNode, + ret: _ReturnType, + first_arg: cfg.Variable, + ctx: "context.Context", + func_name: str | None = None, +) -> cfg.Variable | None: """Returns a variable of the return value of a type guard function. Args: @@ -1195,11 +1297,16 @@ def handle_typeguard(node, ret: _ReturnType, first_arg, ctx, func_name=None): return typeguard_return -def build_paramspec_signature(pspec_match, r_args, return_value, ctx): +def build_paramspec_signature( + pspec_match, + r_args: tuple[pytd.TypeU, ...], + return_value: _base.BaseValue, + ctx: "context.Context", +) -> Signature: """Build a signature from a ParamSpecMatch and Callable args.""" - sig = pspec_match.sig + sig: Signature = pspec_match.sig ann = sig.annotations.copy() - ann["return"] = return_value + ann["return"] = return_value # pytype: disable=container-type-mismatch ret_posargs = [] for i, typ in enumerate(r_args): name = f"_{i}" diff --git a/pytype/typegraph/cfg_utils.py b/pytype/typegraph/cfg_utils.py index 238b6c137..d8f846309 100644 --- a/pytype/typegraph/cfg_utils.py +++ b/pytype/typegraph/cfg_utils.py @@ -21,7 +21,7 @@ def variable_product( variables: list[cfg.Variable], -) -> Iterable[tuple[cfg.Variable, ...]]: +) -> Iterable[tuple[cfg.Binding, ...]]: """Take the Cartesian product of a number of Variables. Args: diff --git a/pytype/vm_utils.py b/pytype/vm_utils.py index 87121d741..ce14ea3c1 100644 --- a/pytype/vm_utils.py +++ b/pytype/vm_utils.py @@ -632,7 +632,7 @@ def _check_defaults(node, method, ctx): "Unexpected argument matching error: %s" % e.__class__.__name__ ) from e for e, arg_name, value in errors: - bad_param = e.bad_call.bad_param + bad_param = e.bad_call.bad_param # pytype: disable=attribute-error expected_type = bad_param.typ if value == ctx.convert.ellipsis: # `...` should be a valid default parameter value for overloads.