Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate type on abstract/_special_classes.py #1825

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 69 additions & 27 deletions pytype/abstract/_special_classes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
"""Classes that need special handling, typically due to code generation."""

from collections.abc import Sequence
from typing import TYPE_CHECKING
from pytype.abstract import abstract_utils
from pytype.abstract import class_mixin
from pytype.pytd import pytd

if TYPE_CHECKING:
from pytype import context # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.typegraph import cfg # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.abstract import _classes # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.overlays import named_tuple # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.overlays import typed_dict # pylint: disable=g-bad-import-order,g-import-not-at-top

def build_class(node, props, kwargs, ctx):

def build_class(
node: "cfg.CFGNode",
props: class_mixin.ClassBuilderProperties,
kwargs: "dict[str, cfg.Variable]",
ctx: "context.Context",
) -> "tuple[cfg.CFGNode, cfg.Variable | None]":
"""Handle classes whose subclasses define their own class constructors."""

for base in props.bases:
Expand All @@ -14,7 +28,9 @@ def build_class(node, props, kwargs, ctx):
continue
if base.is_enum:
enum_base = abstract_utils.get_atomic_value(
ctx.vm.loaded_overlays["enum"].members["Enum"]
ctx.vm.loaded_overlays["enum"].members[
"Enum"
] # pytype: disable=attribute-error
)
return enum_base.make_class(node, props)
elif base.full_name == "typing.NamedTuple":
Expand All @@ -32,34 +48,44 @@ def build_class(node, props, kwargs, ctx):
class _Builder:
"""Build special classes created by inheriting from a specific class."""

def __init__(self, ctx):
def __init__(self, ctx: "context.Context"):
self.ctx = ctx
self.convert = ctx.convert

def matches_class(self, c):
def matches_class(self, c: "_classes.PyTDClass"):
raise NotImplementedError()

def matches_base(self, c):
def matches_base(self, c: "_classes.PyTDClass"):
raise NotImplementedError()

def matches_mro(self, c):
def matches_mro(self, c: "_classes.PyTDClass"):
raise NotImplementedError()

def make_base_class(self):
def make_base_class(
self,
) -> (
"typed_dict.TypedDictBuilder | named_tuple.NamedTupleClassBuilder | None"
):
raise NotImplementedError()

def make_derived_class(self, name, pytd_cls):
def make_derived_class(
self, name: str, pytd_cls: "_classes.PyTDClass"
) -> "typed_dict.TypedDictClass | cfg.Variable | None":
raise NotImplementedError()

def maybe_build_from_pytd(self, name, pytd_cls):
def maybe_build_from_pytd(
self, name: str, pytd_cls: pytd.Class
) -> "typed_dict.TypedDictBuilder | named_tuple.NamedTupleClassBuilder | typed_dict.TypedDictClass | cfg.Variable | None":
if self.matches_class(pytd_cls):
return self.make_base_class()
elif self.matches_base(pytd_cls):
return self.make_derived_class(name, pytd_cls)
else:
return None

def maybe_build_from_mro(self, abstract_cls, name, pytd_cls):
def maybe_build_from_mro(
self, abstract_cls: "_classes.PyTDClass", name: str, pytd_cls: pytd.Class
) -> "typed_dict.TypedDictClass | cfg.Variable | None":
if self.matches_mro(abstract_cls):
return self.make_derived_class(name, pytd_cls)
return None
Expand All @@ -68,59 +94,70 @@ def maybe_build_from_mro(self, abstract_cls, name, pytd_cls):
class _TypedDictBuilder(_Builder):
"""Build a typed dict."""

CLASSES = ("typing.TypedDict", "typing_extensions.TypedDict")
# TODO: b/350643999 - Should rather be a ClassVar[Sequence[str]]
CLASSES: Sequence[str] = ("typing.TypedDict", "typing_extensions.TypedDict")

def matches_class(self, c):
def matches_class(self, c: "_classes.PyTDClass") -> bool:
return c.name in self.CLASSES

def matches_base(self, c):
return any(
def matches_base(self, c: "_classes.PyTDClass") -> bool:
return any( # pytype: disable=attribute-error
isinstance(b, pytd.ClassType) and self.matches_class(b) for b in c.bases
)

def matches_mro(self, c):
def matches_mro(self, c: "_classes.PyTDClass") -> bool:
# Check if we have typed dicts in the MRO by seeing if we have already
# created a TypedDictClass for one of the ancestor classes.
return any(
isinstance(b, class_mixin.Class) and b.is_typed_dict_class
for b in c.mro
)

def make_base_class(self):
def make_base_class(self) -> "typed_dict.TypedDictBuilder":
return self.convert.make_typed_dict_builder()

def make_derived_class(self, name, pytd_cls):
def make_derived_class(
self, name: str, pytd_cls: "_classes.PyTDClass"
) -> "typed_dict.TypedDictClass":
return self.convert.make_typed_dict(name, pytd_cls)


class _NamedTupleBuilder(_Builder):
"""Build a namedtuple."""

CLASSES = ("typing.NamedTuple",)
# TODO: b/350643999 - Should rather be a ClassVar[Sequence[str]]
CLASSES: Sequence[str] = ("typing.NamedTuple",)

def matches_class(self, c):
def matches_class(self, c: "_classes.PyTDClass") -> bool:
return c.name in self.CLASSES

def matches_base(self, c):
return any(
def matches_base(self, c: "_classes.PyTDClass") -> bool:
return any( # pytype: disable=attribute-error
isinstance(b, pytd.ClassType) and self.matches_class(b) for b in c.bases
)

def matches_mro(self, c):
def matches_mro(self, c: "_classes.PyTDClass") -> bool:
# We only create namedtuples by direct inheritance
return False

def make_base_class(self):
def make_base_class(self) -> "named_tuple.NamedTupleClassBuilder":
return self.convert.make_namedtuple_builder()

def make_derived_class(self, name, pytd_cls):
def make_derived_class(
self, name: str, pytd_cls: "_classes.PyTDClass"
) -> "cfg.Variable":
return self.convert.make_namedtuple(name, pytd_cls)


_BUILDERS = (_TypedDictBuilder, _NamedTupleBuilder)
_BUILDERS: Sequence[type[_Builder]] = (
_TypedDictBuilder,
_NamedTupleBuilder,
)


def maybe_build_from_pytd(name, pytd_cls, ctx):
def maybe_build_from_pytd(
name: str, pytd_cls: pytd.Class, ctx: "context.Context"
):
"""Try to build a special class from a pytd class."""
for b in _BUILDERS:
ret = b(ctx).maybe_build_from_pytd(name, pytd_cls)
Expand All @@ -129,7 +166,12 @@ def maybe_build_from_pytd(name, pytd_cls, ctx):
return None


def maybe_build_from_mro(abstract_cls, name, pytd_cls, ctx):
def maybe_build_from_mro(
abstract_cls: "_classes.PyTDClass",
name: str,
pytd_cls: pytd.Class,
ctx: "context.Context",
) -> "typed_dict.TypedDictClass | cfg.Variable | None":
"""Try to build a special class from the MRO of an abstract class."""
for b in _BUILDERS:
ret = b(ctx).maybe_build_from_mro(abstract_cls, name, pytd_cls)
Expand Down
Loading