From efd16783f5e083e33913f7d145fd8837d72f45cf Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Mon, 28 Oct 2024 09:50:15 -0700 Subject: [PATCH] Annotate type on abstract/_typing.py PiperOrigin-RevId: 690649731 --- pytype/abstract/_typing.py | 279 +++++++++++++++++++++--------- pytype/overlays/typing_overlay.py | 4 +- 2 files changed, 202 insertions(+), 81 deletions(-) diff --git a/pytype/abstract/_typing.py b/pytype/abstract/_typing.py index 13acff670..f9a4999b2 100644 --- a/pytype/abstract/_typing.py +++ b/pytype/abstract/_typing.py @@ -1,9 +1,9 @@ """Constructs related to type annotations.""" -from collections.abc import Mapping +from collections.abc import Iterable, Iterator, Mapping, Sequence, Set import dataclasses import logging -import typing +from typing import Any, Literal, TYPE_CHECKING, cast from pytype import datatypes from pytype.abstract import _base @@ -12,14 +12,27 @@ from pytype.abstract import abstract_utils from pytype.abstract import function from pytype.abstract import mixin +from pytype.pytd import pytd from pytype.pytd import pytd_utils -log = logging.getLogger(__name__) +if TYPE_CHECKING: + from pytype import context # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _instances # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.abstract import _singletons # pylint: disable=g-bad-import-order,g-import-not-at-top + from pytype.typegraph import cfg # pylint: disable=g-bad-import-order,g-import-not-at-top +log: logging.Logger = logging.getLogger(__name__) -def _get_container_type_key(container): + +def _get_container_type_key( + container: ( + _instance_base.SimpleValue | abstract_utils.DummyContainer | None + ), +): try: - return container.get_type_key() + # TODO: b/350643999 - Probably change this to a `None` check rather than + # an exception. + return container.get_type_key() # pytype: disable=attribute-error except AttributeError: return container @@ -27,19 +40,23 @@ def _get_container_type_key(container): class AnnotationClass(_instance_base.SimpleValue, mixin.HasSlots): """Base class of annotations that can be parameterized.""" - def __init__(self, name, ctx): + def __init__(self, name: str, ctx: "context.Context"): super().__init__(name, ctx) mixin.HasSlots.init_mixin(self) self.set_native_slot("__getitem__", self.getitem_slot) - def getitem_slot(self, node, slice_var): + def getitem_slot( + self, node: _base.BaseValue, slice_var: "cfg.Variable" + ) -> "tuple[cfg.CFGNode, cfg.Variable]": """Custom __getitem__ implementation.""" slice_content = abstract_utils.maybe_extract_tuple(slice_var) inner, ellipses = self._build_inner(slice_content) value = self._build_value(node, tuple(inner), ellipses) return node, value.to_variable(node) - def _build_inner(self, slice_content): + def _build_inner( + self, slice_content: "Iterable[cfg.Variable]" + ) -> "tuple[list[_base.BaseValue], set[int]]": """Build the list of parameters. Args: @@ -66,10 +83,15 @@ def _build_inner(self, slice_content): inner.append(val) return inner, ellipses - def _build_value(self, node, inner, ellipses): + def _build_value( + self, + node: _base.BaseValue, + inner: tuple[_base.BaseValue, ...], + ellipses: set[int], + ): raise NotImplementedError(self.__class__.__name__) - def __repr__(self): + def __repr__(self) -> str: return f"AnnotationClass({self.name})" def _get_class(self): @@ -79,11 +101,13 @@ def _get_class(self): class AnnotationContainer(AnnotationClass): """Implementation of X[...] for annotations.""" - def __init__(self, name, ctx, base_cls): + def __init__( + self, name: str, ctx: "context.Context", base_cls: pytd.Class + ) -> None: super().__init__(name, ctx) self.base_cls = base_cls - def __repr__(self): + def __repr__(self) -> str: return f"AnnotationContainer({self.name})" def _sub_annotation( @@ -120,7 +144,10 @@ def _sub_annotation( return annot def _get_value_info( - self, inner, ellipses, allowed_ellipses=frozenset() + self, + inner: tuple[_base.BaseValue, ...], + ellipses: set[int], + allowed_ellipses: Set[int] = frozenset(), ) -> tuple[ tuple[int | str, ...], tuple[_base.BaseValue, ...], @@ -195,7 +222,12 @@ def _get_value_info( abstract_class = _classes.ParameterizedClass return template, inner, abstract_class - def _validate_inner(self, template, inner, raw_inner): + def _validate_inner( + self, + template: tuple[int | str, ...], + inner: tuple[_base.BaseValue, ...], + raw_inner: tuple[_base.BaseValue, ...], + ) -> Sequence[_base.BaseValue]: """Check that the passed inner values are valid for the given template.""" if isinstance( self.base_cls, _classes.ParameterizedClass @@ -257,7 +289,12 @@ def _validate_inner(self, template, inner, raw_inner): ] return inner - def _build_value(self, node, inner, ellipses): + def _build_value( + self, + node: _base.BaseValue, + inner: tuple[_base.BaseValue, ...], + ellipses: set[int], + ) -> "LateAnnotation | _classes.ParameterizedClass | _singletons.Unsolvable": if self.base_cls.is_late_annotation(): # A parameterized LateAnnotation should be converted to another # LateAnnotation to delay evaluation until the first late annotation is @@ -298,7 +335,7 @@ def _build_value(self, node, inner, ellipses): # Protocol[T, ...] is a shorthand for Protocol, Generic[T, ...]. template_params = [ param.with_scope(base_cls.full_name) - for param in typing.cast(tuple[TypeParameter, ...], processed_inner) + for param in cast(tuple[TypeParameter, ...], processed_inner) ] else: template_params = None @@ -374,14 +411,25 @@ def _build_value(self, node, inner, ellipses): self.ctx.errorlog.invalid_annotation(self.ctx.vm.frames, e.annot, e.error) return self.ctx.convert.unsolvable - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding | None", + args: function.Args, + alias_map: datatypes.UnionFind | None = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": return self._call_helper(node, self.base_cls, func, args) class _TypeVariableInstance(_base.BaseValue): """An instance of a type parameter.""" - def __init__(self, param, instance, ctx): + def __init__( + self, + param: "_TypeVariable", + instance: _instance_base.Instance, + ctx: "context.Context", + ) -> None: super().__init__(param.name, ctx) self.cls = self.param = param self.instance = instance @@ -391,22 +439,28 @@ def __init__(self, param, instance, ctx): def full_name(self): return f"{self.scope}.{self.name}" if self.scope else self.name - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding | None", + args: function.Args, + alias_map: datatypes.UnionFind | None = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": var = self.instance.get_instance_type_parameter(self.name) if var.bindings: return function.call_function(self.ctx, node, var, args) else: return node, self.ctx.convert.empty.to_variable(self.ctx.root_node) - def __eq__(self, other): + def __eq__(self, other: "AnnotationClass") -> bool: if isinstance(other, type(self)): return self.param == other.param and self.instance == other.instance return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash((self.param, self.instance)) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name!r})" @@ -427,14 +481,15 @@ class _TypeVariable(_base.BaseValue): def __init__( self, - name, - ctx, + name: str, + ctx: "context.Context", + # TODO: b/353979649 - Figure out the type of constraints and scope constraints=(), - bound=None, - covariant=False, - contravariant=False, + bound: _classes.InterpreterClass | None = None, + covariant: bool = False, + contravariant: bool = False, scope=None, - ): + ) -> None: super().__init__(name, ctx) # TODO(b/217789659): PEP-612 does not mention constraints, but ParamSpecs # ignore all the extra parameters anyway.. @@ -445,18 +500,19 @@ def __init__( self.scope = scope @_base.BaseValue.module.setter - def module(self, module): + # TODO: b/353979649 - Figure out the type of module + def module(self, module) -> None: super(_TypeVariable, _TypeVariable).module.fset(self, module) self.scope = module @property - def full_name(self): + def full_name(self) -> str: return f"{self.scope}.{self.name}" if self.scope else self.name - def is_generic(self): + def is_generic(self) -> bool: return not self.constraints and not self.bound - def copy(self): + def copy(self) -> "_TypeVariable": return self.__class__( self.name, self.ctx, @@ -467,12 +523,14 @@ def copy(self): self.scope, ) - def with_scope(self, scope): + def with_scope( + self, scope: Literal["typing.Generic"] | Literal["typing.Protocol"] + ) -> "_TypeVariable": res = self.copy() res.scope = scope return res - def __eq__(self, other): + def __eq__(self, other: "_TypeVariable") -> bool: if isinstance(other, type(self)): return ( self.name == other.name @@ -484,10 +542,10 @@ def __eq__(self, other): ) return NotImplemented - def __ne__(self, other): + def __ne__(self, other: "_TypeVariable") -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: return hash(( self.name, self.constraints, @@ -496,7 +554,7 @@ def __hash__(self): self.contravariant, )) - def __repr__(self): + def __repr__(self) -> str: return "{!s}({!r}, constraints={!r}, bound={!r}, module={!r})".format( self.__class__.__name__, self.name, @@ -505,7 +563,11 @@ def __repr__(self): self.scope, ) - def instantiate(self, node, container=None): + def instantiate( + self, + node: "cfg.CFGNode", + container: _instance_base.Instance | None = None, + ) -> "cfg.Variable": var = self.ctx.program.NewVariable() if container and ( not isinstance(container, _instance_base.SimpleValue) @@ -522,7 +584,7 @@ def instantiate(self, node, container=None): var.AddBinding(self.ctx.convert.unsolvable, [], node) return var - def update_official_name(self, name): + def update_official_name(self, name: str) -> None: if self.name != name: message = ( f"TypeVar({self.name!r}) must be stored as {self.name!r}, " @@ -530,68 +592,80 @@ def update_official_name(self, name): ) self.ctx.errorlog.invalid_typevar(self.ctx.vm.frames, message) - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding | None", + args: function.Args, + alias_map: datatypes.UnionFind | None = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": return node, self.instantiate(node) class TypeParameter(_TypeVariable): """Parameter of a type (typing.TypeVar).""" - _INSTANCE_CLASS = TypeParameterInstance + _INSTANCE_CLASS: type[TypeParameterInstance] = TypeParameterInstance class ParamSpec(_TypeVariable): """Parameter of a callable type (typing.ParamSpec).""" - _INSTANCE_CLASS = ParamSpecInstance + _INSTANCE_CLASS: type[ParamSpecInstance] = ParamSpecInstance class ParamSpecArgs(_base.BaseValue): """ParamSpec.args.""" - def __init__(self, paramspec, ctx): + def __init__(self, paramspec: ParamSpec, ctx: "context.Context") -> None: super().__init__(f"{paramspec.name}.args", ctx) self.paramspec = paramspec - def instantiate(self, node, container=None): + def instantiate(self, node: "cfg.CFGNode", container=None) -> "cfg.Variable": return self.to_variable(node) class ParamSpecKwargs(_base.BaseValue): """ParamSpec.kwargs.""" - def __init__(self, paramspec, ctx): + def __init__(self, paramspec: ParamSpec, ctx: "context.Context") -> None: super().__init__(f"{paramspec.name}.kwargs", ctx) self.paramspec = paramspec - def instantiate(self, node, container=None): + def instantiate(self, node: "cfg.CFGNode", container=None) -> "cfg.Variable": return self.to_variable(node) class Concatenate(_base.BaseValue): """Concatenation of args and ParamSpec.""" - def __init__(self, params, ctx): + def __init__(self, params: list[ParamSpec], ctx: "context.Context") -> None: super().__init__("Concatenate", ctx) self.args = params[:-1] self.paramspec = params[-1] @property - def full_name(self): + def full_name(self) -> str: return self.paramspec.full_name - def instantiate(self, node, container=None): + def instantiate( + self, + node: "cfg.CFGNode", + container: ( + _instance_base.SimpleValue | abstract_utils.DummyContainer | None + ) = None, + ): return self.to_variable(node) @property - def num_args(self): + def num_args(self) -> int: return len(self.args) - def get_args(self): + def get_args(self) -> list[ParamSpec]: # Satisfies the same interface as abstract.CallableClass return self.args - def __repr__(self): + def __repr__(self) -> str: args = ", ".join(list(map(repr, self.args)) + [self.paramspec.name]) return f"Concatenate[{args}]" @@ -605,7 +679,13 @@ class Union(_base.BaseValue, mixin.NestedAnnotation, mixin.HasSlots): options: Iterable of instances of BaseValue. """ - def __init__(self, options, ctx): + def __init__( + # TODO: b/353979649 - Rename `options` to something else. + # Possibly `disjuncts`. + self, + options: Iterable[_base.BaseValue], + ctx: "context.Context", + ) -> None: super().__init__("Union", ctx) assert options self.options = list(options) @@ -616,7 +696,7 @@ def __init__(self, options, ctx): mixin.HasSlots.init_mixin(self) self.set_native_slot("__getitem__", self.getitem_slot) - def __repr__(self): + def __repr__(self) -> str: if self._printing: # recursion detected printed_contents = "..." else: @@ -625,15 +705,15 @@ def __repr__(self): self._printing = False return f"{self.name}[{printed_contents}]" - def __eq__(self, other): + def __eq__(self, other: "Union") -> bool: if isinstance(other, type(self)): return self.options == other.options return NotImplemented - def __ne__(self, other): + def __ne__(self, other: "Union") -> bool: return not self == other - def __hash__(self): + def __hash__(self) -> int: # Use the names of the parameter values to approximate a hash, to avoid # infinite recursion on recursive type annotations. return hash(tuple(o.full_name for o in self.options)) @@ -641,14 +721,16 @@ def __hash__(self): def _unique_parameters(self): return [o.to_variable(self.ctx.root_node) for o in self.options] - def _get_class(self): + def _get_class(self) -> _base.BaseValue: classes = {o.cls for o in self.options} if len(classes) > 1: return self.ctx.convert.unsolvable else: return classes.pop() - def getitem_slot(self, node, slice_var): + def getitem_slot( + self, node: "cfg.CFGNode", slice_var + ) -> "tuple[cfg.CFGNode, cfg.Variable]": """Custom __getitem__ implementation.""" slice_content = abstract_utils.maybe_extract_tuple(slice_var) params = self.ctx.annotation_utils.get_type_parameters(self) @@ -677,7 +759,13 @@ def getitem_slot(self, node, slice_var): new = self.ctx.annotation_utils.sub_one_annotation(node, self, [subst]) return node, new.to_variable(node) - def instantiate(self, node, container=None): + def instantiate( + self, + node: "cfg.CFGNode", + container: ( + _instance_base.SimpleValue | abstract_utils.DummyContainer | None + ) = None, + ): var = self.ctx.program.NewVariable() for option in self.options: k = (node, _get_container_type_key(container), option) @@ -692,7 +780,13 @@ def instantiate(self, node, container=None): var.PasteVariable(instance, node) return var - def call(self, node, func, args, alias_map=None): + def call( + self, + node: "cfg.CFGNode", + func: "cfg.Binding", + args: function.Args, + alias_map: datatypes.UnionFind | None = None, + ) -> "tuple[cfg.CFGNode, cfg.Variable]": var = self.ctx.program.NewVariable(self.options, [], node) return function.call_function(self.ctx, node, var, args) @@ -702,13 +796,15 @@ def get_formal_type_parameter(self, t): ] return Union(new_options, self.ctx) - def get_inner_types(self): + def get_inner_types(self) -> Iterator[tuple[int, _base.BaseValue]]: return enumerate(self.options) - def update_inner_type(self, key, typ): + def update_inner_type(self, key: int, typ: _base.BaseValue) -> None: self.options[key] = typ - def replace(self, inner_types): + def replace( + self, inner_types: Sequence[tuple[int, _base.BaseValue]] + ) -> "Union": return self.__class__((v for _, v in sorted(inner_types)), self.ctx) @@ -726,16 +822,20 @@ class LateAnnotation: Use `x.is_late_annotation()` to check whether x is a late annotation. """ - _RESOLVING = object() + _RESOLVING: Any = object() - def __init__(self, expr, stack, ctx, *, typing_imports=None): + def __init__( + self, expr, stack, ctx: "context.Context", *, typing_imports=None + ): self.expr = expr self.stack = stack self.ctx = ctx self.resolved = False # Any new typing imports the annotation needs while resolving. self._typing_imports = typing_imports or set() - self._type = ctx.convert.unsolvable # the resolved type of `expr` + self._type: _base.BaseValue = ( + ctx.convert.unsolvable + ) # the resolved type of `expr` self._unresolved_instances = set() self._resolved_instances = {} # _attribute_names needs to be defined last! This contains the names of all @@ -784,7 +884,7 @@ def unflatten_expr(self): ) return self.expr - def __repr__(self): + def __repr__(self) -> str: return "LateAnnotation({!r}, resolved={!r})".format( self.expr, self._type if self.resolved else None ) @@ -792,10 +892,10 @@ def __repr__(self): # __hash__ and __eq__ need to be explicitly defined for Python to use them in # set/dict comparisons. - def __hash__(self): + def __hash__(self) -> int: return hash(self._type) if self.resolved else hash(self.expr) - def __eq__(self, other): + def __eq__(self, other) -> bool: return hash(self) == hash(other) def __getattribute__(self, name): @@ -814,7 +914,12 @@ def __setattr__(self, name, value): def __contains__(self, name): return self.resolved and name in self._type - def resolve(self, node, f_globals, f_locals): + def resolve( + self, + node: "cfg.CFGNode", + f_globals: "_instances.LazyConcreteDict", + f_locals: "_instances.LazyConcreteDict", + ) -> None: """Resolve the late annotation.""" if self.resolved: return @@ -851,20 +956,26 @@ def resolve(self, node, f_globals, f_locals): self.resolved = True log.info("Resolved late annotation %r to %r", self.expr, self._type) - def set_type(self, typ): + def set_type(self, typ: _base.BaseValue) -> None: # Used by annotation_utils.sub_one_annotation to substitute values into # recursive aliases. assert not self.resolved self.resolved = True self._type = typ - def to_variable(self, node): + def to_variable(self, node: "cfg.CFGNode") -> "cfg.Variable": if self.resolved: return self._type.to_variable(node) else: return _base.BaseValue.to_variable(self, node) # pytype: disable=wrong-arg-types - def instantiate(self, node, container=None): + def instantiate( + self, + node: "cfg.CFGNode", + container: ( + _instance_base.SimpleValue | abstract_utils.DummyContainer | None + ) = None, + ) -> "cfg.Variable": """Instantiate the pointed-to class, or record a placeholder instance.""" if self.resolved: key = (node, _get_container_type_key(container)) @@ -876,16 +987,18 @@ def instantiate(self, node, container=None): self._unresolved_instances.add(instance) return instance.to_variable(node) - def get_special_attribute(self, node, name, valself): + def get_special_attribute( + self, node: "cfg.CFGNode", name: str, valself: "cfg.Variable" + ) -> "cfg.Variable | None": if name == "__getitem__" and not self.resolved: container = _base.BaseValue.to_annotation_container(self) # pytype: disable=wrong-arg-types return container.get_special_attribute(node, name, valself) return self._type.get_special_attribute(node, name, valself) - def is_late_annotation(self): + def is_late_annotation(self) -> bool: return True - def is_recursive(self): + def is_recursive(self) -> bool: """Check whether this is a recursive type.""" if not self.resolved: return False @@ -905,12 +1018,18 @@ def is_recursive(self): class FinalAnnotation(_base.BaseValue): """Container for a Final annotation.""" - def __init__(self, annotation, ctx): + def __init__(self, annotation: _base.BaseValue, ctx: "context.Context"): super().__init__("FinalAnnotation", ctx) self.annotation = annotation - def __repr__(self): + def __repr__(self) -> str: return f"Final[{self.annotation}]" - def instantiate(self, node, container=None): + def instantiate( + self, + node: "cfg.CFGNode", + container: ( + _instance_base.SimpleValue | abstract_utils.DummyContainer | None + ) = None, + ) -> "cfg.Variable": return self.to_variable(node) diff --git a/pytype/overlays/typing_overlay.py b/pytype/overlays/typing_overlay.py index 7e725a1b7..e6c43b7a6 100644 --- a/pytype/overlays/typing_overlay.py +++ b/pytype/overlays/typing_overlay.py @@ -184,7 +184,9 @@ def getitem_slot(self, node, slice_var): inner, ellipses = self._build_inner(content) args = inner[0] if abstract_utils.is_concrete_list(args): - inner[0], inner_ellipses = self._build_inner(args.pyval) + # No attribute 'pyval' on pytype.abstract._base.BaseValue + args_pyval = args.pyval # pytype: disable=attribute-error + inner[0], inner_ellipses = self._build_inner(args_pyval) self.ctx.errorlog.invalid_ellipses( self.ctx.vm.frames, inner_ellipses, args.name )