Skip to content

Commit

Permalink
Merge pull request #1568 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 authored Jan 23, 2024
2 parents 2e30cee + bb0736a commit c6869a1
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 46 deletions.
19 changes: 12 additions & 7 deletions pytype/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,17 @@ def match_var_against_type(self, var, other_type, subst, view):

def _match_type_param_against_type_param(self, t1, t2, subst, view):
"""Match a TypeVar against another TypeVar."""
if t1.full_name == "typing.Self" and not t1.bound:
# We're matching a Self instance before it's been bound to its containing
# class. We know it should be bound but not to what, so `Any` is the best
# we can do.
t1 = t1.copy()
t1.bound = self.ctx.convert.unsolvable
if t1.full_name == "typing.Self":
if t2.full_name == "typing.Self":
# Self always matches itself. We check for this explicitly because Self
# instances may have their bounds set to incompatible classes.
return subst
elif not t1.bound:
# We're matching a Self instance before it's been bound to its
# containing class. We know it should be bound but not to what, so `Any`
# is the best we can do.
t1 = t1.copy()
t1.bound = self.ctx.convert.unsolvable
if t2.constraints:
assert not t2.bound # constraints and bounds are mutually exclusive
# We only check the constraints for t1, not the bound. We wouldn't know
Expand Down Expand Up @@ -1461,7 +1466,7 @@ def _match_dict_against_typed_dict(
for k, v in left.pyval.items():
if k not in fields:
continue
typ = abstract_utils.get_atomic_value(fields[k])
typ = fields[k]
match_result = self.compute_one_match(v, typ)
if not match_result.success:
bad.append((k, match_result.bad_matches))
Expand Down
6 changes: 2 additions & 4 deletions pytype/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,10 +1016,8 @@ def _typed_dict_to_def(self, node, v, name):
keywords.append(("total", pytd.Literal(False)))
bases = (pytd.NamedType("typing.TypedDict"),)
constants = []
for k, var in v.props.fields.items():
typ = pytd_utils.JoinTypes(
self.value_instance_to_pytd_type(node, p, None, set(), {})
for p in var.data)
for k, val in v.props.fields.items():
typ = self.value_instance_to_pytd_type(node, val, None, set(), {})
if v.props.total and k not in v.props.required:
typ = pytd.GenericType(pytd.NamedType("typing.NotRequired"), (typ,))
elif not v.props.total and k in v.props.required:
Expand Down
6 changes: 3 additions & 3 deletions pytype/overlays/dataclass_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def decorate(self, node, cls):
continue
kind = ""
init = True
kw_only = False
kw_only = sticky_kwonly
assert typ
if match_classvar(typ):
continue
Expand All @@ -112,8 +112,8 @@ def decorate(self, node, cls):
field = orig.data[0]
orig = field.default
init = field.init
if self.ctx.python_version >= (3, 10):
kw_only = sticky_kwonly if field.kw_only is None else field.kw_only
if field.kw_only is not None:
kw_only = field.kw_only

if orig and orig.data == [self.ctx.convert.none]:
# vm._apply_annotation mostly takes care of checking that the default
Expand Down
59 changes: 28 additions & 31 deletions pytype/overlays/typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses

from typing import Any, Dict, Optional, Set
from typing import Dict, Optional, Set

from pytype.abstract import abstract
from pytype.abstract import abstract_utils
Expand Down Expand Up @@ -34,7 +34,7 @@ class TypedDictProperties:
"""Collection of typed dict properties passed between various stages."""

name: str
fields: Dict[str, Any]
fields: Dict[str, abstract.BaseValue]
required: Set[str]
total: bool

Expand All @@ -48,27 +48,15 @@ def optional(self):

def add(self, k, v, total):
"""Adds key and value."""
values = []
all_requiredness = set()
for value in v.data:
req = _is_required(value)
if req is None:
values.append(value)
all_requiredness.add(None)
elif isinstance(value, abstract.ParameterizedClass):
values.append(value.formal_type_parameters[abstract_utils.T])
all_requiredness.add(req)
else:
values.append(value.ctx.convert.unsolvable)
all_requiredness.add(req)
if (len(all_requiredness) == 1 and
(requiredness := next(iter(all_requiredness))) is not None):
final_v = v.program.NewVariable(values, [], v.program.entrypoint)
required = requiredness
req = _is_required(v)
if req is None:
value = v
elif isinstance(v, abstract.ParameterizedClass):
value = v.formal_type_parameters[abstract_utils.T]
else:
final_v = v
required = total
self.fields[k] = final_v # pylint: disable=unsupported-assignment-operation
value = v.ctx.convert.unsolvable
required = total if req is None else req
self.fields[k] = value # pylint: disable=unsupported-assignment-operation
if required:
self.required.add(k)

Expand Down Expand Up @@ -122,7 +110,12 @@ def _extract_args(self, args):
name=name, fields={}, required=set(), total=total)
# Force Required/NotRequired evaluation
for k, v in fields.items():
props.add(k, v, total)
try:
value = abstract_utils.get_atomic_value(v)
except abstract_utils.ConversionError:
self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, v.data, k)
value = self.ctx.convert.unsolvable
props.add(k, value, total)
return props

def _validate_bases(self, cls_name, bases):
Expand Down Expand Up @@ -182,8 +175,14 @@ def make_class(self, node, bases, f_locals, total):
ordering=classgen.Ordering.FIRST_ANNOTATE,
ctx=self.ctx)
for k, local in cls_locals.items():
assert local.typ
props.add(k, local.typ, total)
var = local.typ
assert var
try:
typ = abstract_utils.get_atomic_value(var)
except abstract_utils.ConversionError:
self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, var.data, k)
typ = self.ctx.convert.unsolvable
props.add(k, typ, total)

# Process base classes and generate the __init__ signature.
self._validate_bases(cls_name, bases)
Expand All @@ -207,7 +206,7 @@ def make_class_from_pyi(self, cls_name, pytd_cls):
name=name, fields={}, required=set(), total=total)

for c in pytd_cls.constants:
typ = self.ctx.convert.constant_to_var(c.type)
typ = self.ctx.convert.constant_to_value(c.type)
props.add(c.name, typ, total)

# Process base classes and generate the __init__ signature.
Expand Down Expand Up @@ -239,8 +238,7 @@ def _make_init(self, props):
sig = function.Signature.from_param_names(
f"{props.name}.__init__", props.fields.keys(),
kind=pytd.ParameterKind.KWONLY)
sig.annotations = {k: abstract_utils.get_atomic_value(v)
for k, v in props.fields.items()}
sig.annotations = dict(props.fields)
sig.defaults = {k: self.ctx.new_unsolvable(self.ctx.root_node)
for k in props.optional}
return abstract.SimpleFunction(sig, self.ctx)
Expand All @@ -256,8 +254,7 @@ def _new_instance(self, container, node, args):
def instantiate_value(self, node, container):
args = function.Args(())
for name, typ in self.props.fields.items():
args.namedargs[name] = self.ctx.join_variables(
node, [t.instantiate(node) for t in typ.data])
args.namedargs[name] = typ.instantiate(node)
return self._new_instance(container, node, args)

def instantiate(self, node, container=None):
Expand Down Expand Up @@ -301,7 +298,7 @@ def _check_str_key(self, name):

def _check_str_key_value(self, node, name, value_var):
self._check_str_key(name)
typ = abstract_utils.get_atomic_value(self.fields[name])
typ = self.fields[name]
bad = self.ctx.matcher(node).compute_one_match(value_var, typ).bad_matches
for match in bad:
self.ctx.errorlog.annotation_type_mismatch(
Expand Down
3 changes: 2 additions & 1 deletion pytype/overriding_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ def is_subtype(this_type, that_type):
"""Return True iff this_type is a subclass of that_type."""
if this_type == ctx.convert.never:
return True # Never is the bottom type, so it matches everything
this_type_instance = this_type.instantiate(ctx.root_node, None)
this_type_instance = this_type.instantiate(
ctx.root_node, container=abstract_utils.DUMMY_CONTAINER)
return matcher.compute_one_match(this_type_instance, that_type).success

check_result = (
Expand Down
20 changes: 20 additions & 0 deletions pytype/tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,26 @@ class A:
def __init__(self, a1: int, a3: int, *, a2: int = ...) -> None: ...
""")

@test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10")
def test_kwonly_and_nonfield_default(self):
ty = self.Infer("""
import dataclasses
@dataclasses.dataclass
class C:
_: dataclasses.KW_ONLY
x: int = 0
y: str
""")
self.assertTypesMatchPytd(ty, """
import dataclasses
@dataclasses.dataclass
class C:
x: int = ...
y: str
_: dataclasses.KW_ONLY
def __init__(self, *, x: int = ..., y: str) -> None: ...
""")

def test_star_import(self):
with self.DepTree([("foo.pyi", """
import dataclasses
Expand Down
28 changes: 28 additions & 0 deletions pytype/tests/test_flax_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,34 @@ def __init__(
def replace(self: _TBaz, **kwargs) -> _TBaz: ...
""")

@test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10")
def test_kwonly(self):
with test_utils.Tempdir() as d:
self._setup_linen_pyi(d)
ty = self.Infer("""
import dataclasses
from flax import linen as nn
class C(nn.Module):
_: dataclasses.KW_ONLY
x: int = 0
y: str
""", pythonpath=[d.path])
self.assertTypesMatchPytd(ty, """
import dataclasses
from flax import linen as nn
from typing import Any, TypeVar
_TC = TypeVar('_TC', bound=C)
@dataclasses.dataclass
class C(nn.module.Module):
x: int = ...
y: str
_: dataclasses.KW_ONLY
def __init__(self, *, x: int = ..., y: str, name: str = ..., parent: Any = ...) -> None: ...
def replace(self: _TC, **kwargs) -> _TC: ...
""")


if __name__ == "__main__":
test_base.main()
18 changes: 18 additions & 0 deletions pytype/tests/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,14 @@ def f() -> TD:
return __any_object__
""")

def test_duplicate_key(self):
self.CheckWithErrors("""
from typing_extensions import TypedDict
class TD(TypedDict): # invalid-annotation
x: int
x: str
""")


class TypedDictFunctionalTest(test_base.BaseTest):
"""Tests for typing.TypedDict functional constructor."""
Expand Down Expand Up @@ -458,6 +466,16 @@ class X(TypedDict, total=False):
name: str
""")

def test_ambiguous_field_type(self):
self.CheckWithErrors("""
from typing_extensions import TypedDict
if __random__:
v = str
else:
v = int
X = TypedDict('X', {'k': v}) # invalid-annotation
""")


_SINGLE = """
from typing import TypedDict
Expand Down
11 changes: 11 additions & 0 deletions pytype/tests/test_typing_self.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ def f(self) -> Self:
return self
""")

def test_signature_compatibility(self):
self.Check("""
from typing_extensions import Self
class Parent:
def add(self, other: Self) -> Self:
return self
class Child(Parent):
def add(self, other: Self) -> Self:
return self
""")


class SelfPyiTest(test_base.BaseTest):
"""Tests for typing.Self usage in type stubs."""
Expand Down

0 comments on commit c6869a1

Please sign in to comment.