Skip to content

Commit

Permalink
Annotate type on abstract/_instance.base.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689299928
  • Loading branch information
h-joo authored and copybara-github committed Oct 24, 2024
1 parent d6e4cde commit 15cc491
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pytype/abstract/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def to_pytd_def(self, node, name):
"""Get a PyTD definition for this object."""
return self.ctx.pytd_convert.value_to_pytd_def(node, self, name)

def get_default_type_key(self) -> "type[BaseValue]":
def get_default_type_key(self) -> "type[BaseValue] | frozenset":
"""Gets a default type key. See get_type_key."""
return type(self)

Expand Down
4 changes: 2 additions & 2 deletions pytype/abstract/_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ def argcount(self, _: "cfg.CFGNode") -> int:
def call(
self,
node: "cfg.CFGNode",
func: Function,
func: "cfg.Binding",
args: function.Args,
alias_map: "datatypes.UnionFind | None" = None,
):
) -> "tuple[cfg.CFGNode, cfg.Variable]":
sig = None
if isinstance(
self.func.__self__, # pytype: disable=attribute-error
Expand Down
103 changes: 73 additions & 30 deletions pytype/abstract/_instance_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Abstract representation of instances."""

import logging
from typing import TYPE_CHECKING

from pytype import datatypes
from pytype.abstract import _base
Expand All @@ -9,9 +10,14 @@
from pytype.abstract import function
from pytype.errors import error_types

log = logging.getLogger(__name__)
log: logging.Logger = logging.getLogger(__name__)
_isinstance = abstract_utils._isinstance # pylint: disable=protected-access

if TYPE_CHECKING:
from pytype import context # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.abstract import _typing # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.typegraph import cfg # pylint: disable=g-bad-import-order,g-import-not-at-top


class SimpleValue(_base.BaseValue):
"""A basic abstract value that represents instances.
Expand All @@ -26,7 +32,7 @@ class may vary.
members: A name->value dictionary of the instance's attributes.
"""

def __init__(self, name, ctx):
def __init__(self, name: str, ctx: "context.Context"):
"""Initialize a SimpleValue.
Args:
Expand All @@ -38,47 +44,53 @@ def __init__(self, name, ctx):
self.members = datatypes.MonitorDict()
# Lazily loaded to handle recursive types.
# See Instance._load_instance_type_parameters().
self._instance_type_parameters = datatypes.AliasingMonitorDict()
self._instance_type_parameters: (
"datatypes.AliasingMonitorDict[str, cfg.Variable]"
) = datatypes.AliasingMonitorDict()
# This attribute depends on self.cls, which isn't yet set to its true value.
self._maybe_missing_members = None
self._maybe_missing_members: bool | None = None
# The latter caches the result of get_type_key. This is a recursive function
# that has the potential to generate too many calls for large definitions.
self._type_key = None
self._type_key: "frozenset[_base.BaseValue | _typing.LateAnnotation | tuple[str, frozenset]] | None" = (None)
self._fullhash = None
self._cached_changestamps = self._get_changestamps()

def _get_changestamps(self):
def _get_changestamps(self) -> "tuple[int, int]":
return (
self.members.changestamp,
self._instance_type_parameters.changestamp,
)

@property
def instance_type_parameters(self):
def instance_type_parameters(
self,
) -> "datatypes.AliasingMonitorDict[str, cfg.Variable]":
return self._instance_type_parameters

@property
def maybe_missing_members(self):
def maybe_missing_members(self) -> bool:
if self._maybe_missing_members is None:
# maybe_missing_members indicates that every attribute access on this
# object should always succeed. This is usually indicated by the class
# setting _HAS_DYNAMIC_ATTRIBUTES = True.
# This should apply to both the class and instances of the class.
dyn_self = isinstance(self, class_mixin.Class) and self.is_dynamic
dyn_cls = isinstance(self.cls, class_mixin.Class) and self.cls.is_dynamic
self._maybe_missing_members = dyn_self or dyn_cls
self._maybe_missing_members = bool(dyn_self or dyn_cls)
return self._maybe_missing_members

@maybe_missing_members.setter
def maybe_missing_members(self, v):
def maybe_missing_members(self, v: bool) -> None:
self._maybe_missing_members = v

def has_instance_type_parameter(self, name):
def has_instance_type_parameter(self, name: str) -> bool:
"""Check if the key is in `instance_type_parameters`."""
name = abstract_utils.full_type_name(self, name)
return name in self.instance_type_parameters

def get_instance_type_parameter(self, name, node=None):
def get_instance_type_parameter(
self, name: str, node: "cfg.CFGNode | None" = None
) -> "cfg.Variable":
name = abstract_utils.full_type_name(self, name)
param = self.instance_type_parameters.get(name)
if not param:
Expand All @@ -89,7 +101,9 @@ def get_instance_type_parameter(self, name, node=None):
self.instance_type_parameters[name] = param
return param

def merge_instance_type_parameter(self, node, name, value):
def merge_instance_type_parameter(
self, node: "cfg.CFGNode|None", name: str, value: "cfg.Variable"
) -> None:
"""Set the value of a type parameter.
This will always add to the type parameter unlike set_attribute which will
Expand All @@ -109,7 +123,13 @@ def merge_instance_type_parameter(self, node, name, value):
else:
self.instance_type_parameters[name] = value

def _call_helper(self, node, obj, binding, args):
def _call_helper(
self,
node: "cfg.CFGNode",
obj,
binding: "cfg.Binding",
args: function.Args,
) -> "tuple[cfg.CFGNode, cfg.Variable]":
obj_binding = binding if obj == binding.data else obj.to_binding(node)
node, var = self.ctx.attribute_handler.get_attribute(
node, obj, "__call__", obj_binding
Expand All @@ -119,10 +139,16 @@ def _call_helper(self, node, obj, binding, args):
else:
raise error_types.NotCallable(self)

def call(self, node, func, args, alias_map=None):
def call(
self,
node: "cfg.CFGNode",
func: "cfg.Binding",
args: function.Args,
alias_map: datatypes.UnionFind | None = None,
) -> "tuple[cfg.CFGNode, cfg.Variable]":
return self._call_helper(node, self, func, args)

def argcount(self, node):
def argcount(self, node: "cfg.CFGNode") -> int:
node, var = self.ctx.attribute_handler.get_attribute(
node, self, "__call__", self.to_binding(node)
)
Expand All @@ -133,7 +159,7 @@ def argcount(self, node):
# value will lead to a not-callable error anyways.
return 0

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.name} [{self.cls!r}]>"

def _get_class(self):
Expand All @@ -152,7 +178,9 @@ def cls(self):
def cls(self, cls):
self._cls = cls

def set_class(self, node, var):
def set_class(
self, node: "cfg.CFGNode", var: "cfg.Variable"
) -> "cfg.CFGNode":
"""Set the __class__ of an instance, for code that does "x.__class__ = y."""
# Simplification: Setting __class__ is done rarely, and supporting this
# action would complicate pytype considerably by forcing us to track the
Expand All @@ -166,15 +194,15 @@ def set_class(self, node, var):
self.cls = self.ctx.convert.unsolvable
return node

def update_caches(self, force=False):
def update_caches(self, force: bool = False) -> None:
cur_changestamps = self._get_changestamps()
if self._cached_changestamps == cur_changestamps and not force:
return
self._fullhash = None
self._type_key = None
self._cached_changestamps = cur_changestamps

def get_fullhash(self, seen=None):
def get_fullhash(self, seen: set[int] | None = None) -> int:
self.update_caches()
if not self._fullhash:
if seen is None:
Expand All @@ -190,7 +218,9 @@ def get_fullhash(self, seen=None):
self._fullhash = hash(tuple(components))
return self._fullhash

def get_type_key(self, seen=None):
def get_type_key(
self, seen: set[_base.BaseValue] | None = None
) -> "frozenset[_base.BaseValue | _typing.LateAnnotation | tuple[str, frozenset]] | type[_base.BaseValue]":
self.update_caches()
if not self._type_key:
if seen is None:
Expand All @@ -205,33 +235,42 @@ def get_type_key(self, seen=None):
self._type_key = frozenset(key)
return self._type_key

def _unique_parameters(self):
def _unique_parameters(self) -> "list[cfg.Variable]":
parameters = super()._unique_parameters()
parameters.extend(self.instance_type_parameters.values())
return parameters

def instantiate(self, node, container=None):
def instantiate(self, node: "cfg.CFGNode", container=None) -> "cfg.Variable":
return Instance(self, self.ctx, container).to_variable(node)


class Instance(SimpleValue):
"""An instance of some object."""

def __init__(self, cls, ctx, container=None):
def __init__(
self,
cls: "_base.BaseValue | _typing.LateAnnotation",
ctx: "context.Context",
container=None,
) -> None:
super().__init__(cls.name, ctx)
self.cls = cls
self._instance_type_parameters_loaded = False
self._container = container
cls.register_instance(self)

def _load_instance_type_parameters(self):
def _load_instance_type_parameters(self) -> None:
if self._instance_type_parameters_loaded:
return
all_formal_type_parameters = datatypes.AliasingDict()
all_formal_type_parameters: "datatypes.AliasingDict[str, SimpleValue]" = (
datatypes.AliasingDict()
)
abstract_utils.parse_formal_type_parameters(
self.cls, None, all_formal_type_parameters, self._container
)
self._instance_type_parameters = self._instance_type_parameters.copy(
self._instance_type_parameters: (
"datatypes.AliasingDict[str, cfg.Variable]"
) = self._instance_type_parameters.copy(
aliases=all_formal_type_parameters.aliases
)
for name, param in all_formal_type_parameters.items():
Expand All @@ -249,15 +288,19 @@ def _load_instance_type_parameters(self):
self._instance_type_parameters_loaded = True

@property
def full_name(self):
def full_name(self) -> str:
return self.cls.full_name

@property
def instance_type_parameters(self):
def instance_type_parameters(
self,
) -> "datatypes.AliasingDict[str, cfg.Variable]":
self._load_instance_type_parameters()
return self._instance_type_parameters

def get_type_key(self, seen=None):
def get_type_key(
self, seen: set[_base.BaseValue] | None = None
) -> "frozenset[_base.BaseValue | _typing.LateAnnotation | tuple[str, frozenset]] | type[_base.BaseValue|_typing.LateAnnotation]":
if not self._type_key and not self._instance_type_parameters_loaded:
# If we might be the middle of loading this class, don't try to access
# instance_type_parameters. We don't cache this intermediate type key
Expand Down
2 changes: 1 addition & 1 deletion pytype/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ def _match_instance_against_type(self, left, other_type, subst, view):
return subst if left.pyval == other_value.pyval else None
elif (
isinstance(left, abstract.Instance)
and left.cls.is_enum
and left.cls.is_enum # pytype: disable=attribute-error
and isinstance(other_value, abstract.Instance)
and other_value.cls.is_enum
):
Expand Down

0 comments on commit 15cc491

Please sign in to comment.