Skip to content

Commit

Permalink
Annotate type on abstract/_special_classes.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 690620520
  • Loading branch information
h-joo authored and copybara-github committed Oct 28, 2024
1 parent 8fff648 commit 00573d1
Showing 1 changed file with 69 additions and 27 deletions.
96 changes: 69 additions & 27 deletions pytype/abstract/_special_classes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
"""Classes that need special handling, typically due to code generation."""

from collections.abc import Sequence
from typing import TYPE_CHECKING
from pytype.abstract import abstract_utils
from pytype.abstract import class_mixin
from pytype.pytd import pytd

if TYPE_CHECKING:
from pytype import context # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.typegraph import cfg # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.abstract import _classes # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.overlays import named_tuple # pylint: disable=g-bad-import-order,g-import-not-at-top
from pytype.overlays import typed_dict # pylint: disable=g-bad-import-order,g-import-not-at-top

def build_class(node, props, kwargs, ctx):

def build_class(
node: "cfg.CFGNode",
props: class_mixin.ClassBuilderProperties,
kwargs: "dict[str, cfg.Variable]",
ctx: "context.Context",
) -> "tuple[cfg.CFGNode, cfg.Variable | None]":
"""Handle classes whose subclasses define their own class constructors."""

for base in props.bases:
Expand All @@ -14,7 +28,9 @@ def build_class(node, props, kwargs, ctx):
continue
if base.is_enum:
enum_base = abstract_utils.get_atomic_value(
ctx.vm.loaded_overlays["enum"].members["Enum"]
ctx.vm.loaded_overlays["enum"].members[
"Enum"
] # pytype: disable=attribute-error
)
return enum_base.make_class(node, props)
elif base.full_name == "typing.NamedTuple":
Expand All @@ -32,34 +48,44 @@ def build_class(node, props, kwargs, ctx):
class _Builder:
"""Build special classes created by inheriting from a specific class."""

def __init__(self, ctx):
def __init__(self, ctx: "context.Context"):
self.ctx = ctx
self.convert = ctx.convert

def matches_class(self, c):
def matches_class(self, c: "_classes.PyTDClass"):
raise NotImplementedError()

def matches_base(self, c):
def matches_base(self, c: "_classes.PyTDClass"):
raise NotImplementedError()

def matches_mro(self, c):
def matches_mro(self, c: "_classes.PyTDClass"):
raise NotImplementedError()

def make_base_class(self):
def make_base_class(
self,
) -> (
"typed_dict.TypedDictBuilder | named_tuple.NamedTupleClassBuilder | None"
):
raise NotImplementedError()

def make_derived_class(self, name, pytd_cls):
def make_derived_class(
self, name: str, pytd_cls: "_classes.PyTDClass"
) -> "typed_dict.TypedDictClass | cfg.Variable | None":
raise NotImplementedError()

def maybe_build_from_pytd(self, name, pytd_cls):
def maybe_build_from_pytd(
self, name: str, pytd_cls: pytd.Class
) -> "typed_dict.TypedDictBuilder | named_tuple.NamedTupleClassBuilder | typed_dict.TypedDictClass | cfg.Variable | None":
if self.matches_class(pytd_cls):
return self.make_base_class()
elif self.matches_base(pytd_cls):
return self.make_derived_class(name, pytd_cls)
else:
return None

def maybe_build_from_mro(self, abstract_cls, name, pytd_cls):
def maybe_build_from_mro(
self, abstract_cls: "_classes.PyTDClass", name: str, pytd_cls: pytd.Class
) -> "typed_dict.TypedDictClass | cfg.Variable | None":
if self.matches_mro(abstract_cls):
return self.make_derived_class(name, pytd_cls)
return None
Expand All @@ -68,59 +94,70 @@ def maybe_build_from_mro(self, abstract_cls, name, pytd_cls):
class _TypedDictBuilder(_Builder):
"""Build a typed dict."""

CLASSES = ("typing.TypedDict", "typing_extensions.TypedDict")
# TODO: b/350643999 - Should rather be a ClassVar[Sequence[str]]
CLASSES: Sequence[str] = ("typing.TypedDict", "typing_extensions.TypedDict")

def matches_class(self, c):
def matches_class(self, c: "_classes.PyTDClass") -> bool:
return c.name in self.CLASSES

def matches_base(self, c):
return any(
def matches_base(self, c: "_classes.PyTDClass") -> bool:
return any( # pytype: disable=attribute-error
isinstance(b, pytd.ClassType) and self.matches_class(b) for b in c.bases
)

def matches_mro(self, c):
def matches_mro(self, c: "_classes.PyTDClass") -> bool:
# Check if we have typed dicts in the MRO by seeing if we have already
# created a TypedDictClass for one of the ancestor classes.
return any(
isinstance(b, class_mixin.Class) and b.is_typed_dict_class
for b in c.mro
)

def make_base_class(self):
def make_base_class(self) -> "typed_dict.TypedDictBuilder":
return self.convert.make_typed_dict_builder()

def make_derived_class(self, name, pytd_cls):
def make_derived_class(
self, name: str, pytd_cls: "_classes.PyTDClass"
) -> "typed_dict.TypedDictClass":
return self.convert.make_typed_dict(name, pytd_cls)


class _NamedTupleBuilder(_Builder):
"""Build a namedtuple."""

CLASSES = ("typing.NamedTuple",)
# TODO: b/350643999 - Should rather be a ClassVar[Sequence[str]]
CLASSES: Sequence[str] = ("typing.NamedTuple",)

def matches_class(self, c):
def matches_class(self, c: "_classes.PyTDClass") -> bool:
return c.name in self.CLASSES

def matches_base(self, c):
return any(
def matches_base(self, c: "_classes.PyTDClass") -> bool:
return any( # pytype: disable=attribute-error
isinstance(b, pytd.ClassType) and self.matches_class(b) for b in c.bases
)

def matches_mro(self, c):
def matches_mro(self, c: "_classes.PyTDClass") -> bool:
# We only create namedtuples by direct inheritance
return False

def make_base_class(self):
def make_base_class(self) -> "named_tuple.NamedTupleClassBuilder":
return self.convert.make_namedtuple_builder()

def make_derived_class(self, name, pytd_cls):
def make_derived_class(
self, name: str, pytd_cls: "_classes.PyTDClass"
) -> "cfg.Variable":
return self.convert.make_namedtuple(name, pytd_cls)


_BUILDERS = (_TypedDictBuilder, _NamedTupleBuilder)
_BUILDERS: Sequence[type[_Builder]] = (
_TypedDictBuilder,
_NamedTupleBuilder,
)


def maybe_build_from_pytd(name, pytd_cls, ctx):
def maybe_build_from_pytd(
name: str, pytd_cls: pytd.Class, ctx: "context.Context"
):
"""Try to build a special class from a pytd class."""
for b in _BUILDERS:
ret = b(ctx).maybe_build_from_pytd(name, pytd_cls)
Expand All @@ -129,7 +166,12 @@ def maybe_build_from_pytd(name, pytd_cls, ctx):
return None


def maybe_build_from_mro(abstract_cls, name, pytd_cls, ctx):
def maybe_build_from_mro(
abstract_cls: "_classes.PyTDClass",
name: str,
pytd_cls: pytd.Class,
ctx: "context.Context",
) -> "typed_dict.TypedDictClass | cfg.Variable | None":
"""Try to build a special class from the MRO of an abstract class."""
for b in _BUILDERS:
ret = b(ctx).maybe_build_from_mro(abstract_cls, name, pytd_cls)
Expand Down

0 comments on commit 00573d1

Please sign in to comment.