diff --git a/pytype/matcher.py b/pytype/matcher.py index 964730a63..c18e12881 100644 --- a/pytype/matcher.py +++ b/pytype/matcher.py @@ -456,12 +456,17 @@ def match_var_against_type(self, var, other_type, subst, view): def _match_type_param_against_type_param(self, t1, t2, subst, view): """Match a TypeVar against another TypeVar.""" - if t1.full_name == "typing.Self" and not t1.bound: - # We're matching a Self instance before it's been bound to its containing - # class. We know it should be bound but not to what, so `Any` is the best - # we can do. - t1 = t1.copy() - t1.bound = self.ctx.convert.unsolvable + if t1.full_name == "typing.Self": + if t2.full_name == "typing.Self": + # Self always matches itself. We check for this explicitly because Self + # instances may have their bounds set to incompatible classes. + return subst + elif not t1.bound: + # We're matching a Self instance before it's been bound to its + # containing class. We know it should be bound but not to what, so `Any` + # is the best we can do. + t1 = t1.copy() + t1.bound = self.ctx.convert.unsolvable if t2.constraints: assert not t2.bound # constraints and bounds are mutually exclusive # We only check the constraints for t1, not the bound. We wouldn't know @@ -1461,7 +1466,7 @@ def _match_dict_against_typed_dict( for k, v in left.pyval.items(): if k not in fields: continue - typ = abstract_utils.get_atomic_value(fields[k]) + typ = fields[k] match_result = self.compute_one_match(v, typ) if not match_result.success: bad.append((k, match_result.bad_matches)) diff --git a/pytype/output.py b/pytype/output.py index 3d8ed112a..1babd3486 100644 --- a/pytype/output.py +++ b/pytype/output.py @@ -1016,10 +1016,8 @@ def _typed_dict_to_def(self, node, v, name): keywords.append(("total", pytd.Literal(False))) bases = (pytd.NamedType("typing.TypedDict"),) constants = [] - for k, var in v.props.fields.items(): - typ = pytd_utils.JoinTypes( - self.value_instance_to_pytd_type(node, p, None, set(), {}) - for p in var.data) + for k, val in v.props.fields.items(): + typ = self.value_instance_to_pytd_type(node, val, None, set(), {}) if v.props.total and k not in v.props.required: typ = pytd.GenericType(pytd.NamedType("typing.NotRequired"), (typ,)) elif not v.props.total and k in v.props.required: diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index 5cbde5358..33a356c58 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -97,7 +97,7 @@ def decorate(self, node, cls): continue kind = "" init = True - kw_only = False + kw_only = sticky_kwonly assert typ if match_classvar(typ): continue @@ -112,8 +112,8 @@ def decorate(self, node, cls): field = orig.data[0] orig = field.default init = field.init - if self.ctx.python_version >= (3, 10): - kw_only = sticky_kwonly if field.kw_only is None else field.kw_only + if field.kw_only is not None: + kw_only = field.kw_only if orig and orig.data == [self.ctx.convert.none]: # vm._apply_annotation mostly takes care of checking that the default diff --git a/pytype/overlays/typed_dict.py b/pytype/overlays/typed_dict.py index 46a6ec569..8cee96637 100644 --- a/pytype/overlays/typed_dict.py +++ b/pytype/overlays/typed_dict.py @@ -2,7 +2,7 @@ import dataclasses -from typing import Any, Dict, Optional, Set +from typing import Dict, Optional, Set from pytype.abstract import abstract from pytype.abstract import abstract_utils @@ -34,7 +34,7 @@ class TypedDictProperties: """Collection of typed dict properties passed between various stages.""" name: str - fields: Dict[str, Any] + fields: Dict[str, abstract.BaseValue] required: Set[str] total: bool @@ -48,27 +48,15 @@ def optional(self): def add(self, k, v, total): """Adds key and value.""" - values = [] - all_requiredness = set() - for value in v.data: - req = _is_required(value) - if req is None: - values.append(value) - all_requiredness.add(None) - elif isinstance(value, abstract.ParameterizedClass): - values.append(value.formal_type_parameters[abstract_utils.T]) - all_requiredness.add(req) - else: - values.append(value.ctx.convert.unsolvable) - all_requiredness.add(req) - if (len(all_requiredness) == 1 and - (requiredness := next(iter(all_requiredness))) is not None): - final_v = v.program.NewVariable(values, [], v.program.entrypoint) - required = requiredness + req = _is_required(v) + if req is None: + value = v + elif isinstance(v, abstract.ParameterizedClass): + value = v.formal_type_parameters[abstract_utils.T] else: - final_v = v - required = total - self.fields[k] = final_v # pylint: disable=unsupported-assignment-operation + value = v.ctx.convert.unsolvable + required = total if req is None else req + self.fields[k] = value # pylint: disable=unsupported-assignment-operation if required: self.required.add(k) @@ -122,7 +110,12 @@ def _extract_args(self, args): name=name, fields={}, required=set(), total=total) # Force Required/NotRequired evaluation for k, v in fields.items(): - props.add(k, v, total) + try: + value = abstract_utils.get_atomic_value(v) + except abstract_utils.ConversionError: + self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, v.data, k) + value = self.ctx.convert.unsolvable + props.add(k, value, total) return props def _validate_bases(self, cls_name, bases): @@ -182,8 +175,14 @@ def make_class(self, node, bases, f_locals, total): ordering=classgen.Ordering.FIRST_ANNOTATE, ctx=self.ctx) for k, local in cls_locals.items(): - assert local.typ - props.add(k, local.typ, total) + var = local.typ + assert var + try: + typ = abstract_utils.get_atomic_value(var) + except abstract_utils.ConversionError: + self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, var.data, k) + typ = self.ctx.convert.unsolvable + props.add(k, typ, total) # Process base classes and generate the __init__ signature. self._validate_bases(cls_name, bases) @@ -207,7 +206,7 @@ def make_class_from_pyi(self, cls_name, pytd_cls): name=name, fields={}, required=set(), total=total) for c in pytd_cls.constants: - typ = self.ctx.convert.constant_to_var(c.type) + typ = self.ctx.convert.constant_to_value(c.type) props.add(c.name, typ, total) # Process base classes and generate the __init__ signature. @@ -239,8 +238,7 @@ def _make_init(self, props): sig = function.Signature.from_param_names( f"{props.name}.__init__", props.fields.keys(), kind=pytd.ParameterKind.KWONLY) - sig.annotations = {k: abstract_utils.get_atomic_value(v) - for k, v in props.fields.items()} + sig.annotations = dict(props.fields) sig.defaults = {k: self.ctx.new_unsolvable(self.ctx.root_node) for k in props.optional} return abstract.SimpleFunction(sig, self.ctx) @@ -256,8 +254,7 @@ def _new_instance(self, container, node, args): def instantiate_value(self, node, container): args = function.Args(()) for name, typ in self.props.fields.items(): - args.namedargs[name] = self.ctx.join_variables( - node, [t.instantiate(node) for t in typ.data]) + args.namedargs[name] = typ.instantiate(node) return self._new_instance(container, node, args) def instantiate(self, node, container=None): @@ -301,7 +298,7 @@ def _check_str_key(self, name): def _check_str_key_value(self, node, name, value_var): self._check_str_key(name) - typ = abstract_utils.get_atomic_value(self.fields[name]) + typ = self.fields[name] bad = self.ctx.matcher(node).compute_one_match(value_var, typ).bad_matches for match in bad: self.ctx.errorlog.annotation_type_mismatch( diff --git a/pytype/overriding_checks.py b/pytype/overriding_checks.py index bc3afc686..eb629dcdd 100644 --- a/pytype/overriding_checks.py +++ b/pytype/overriding_checks.py @@ -414,7 +414,8 @@ def is_subtype(this_type, that_type): """Return True iff this_type is a subclass of that_type.""" if this_type == ctx.convert.never: return True # Never is the bottom type, so it matches everything - this_type_instance = this_type.instantiate(ctx.root_node, None) + this_type_instance = this_type.instantiate( + ctx.root_node, container=abstract_utils.DUMMY_CONTAINER) return matcher.compute_one_match(this_type_instance, that_type).success check_result = ( diff --git a/pytype/tests/test_dataclasses.py b/pytype/tests/test_dataclasses.py index 7a09cfbc1..8c46e2697 100644 --- a/pytype/tests/test_dataclasses.py +++ b/pytype/tests/test_dataclasses.py @@ -761,6 +761,26 @@ class A: def __init__(self, a1: int, a3: int, *, a2: int = ...) -> None: ... """) + @test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10") + def test_kwonly_and_nonfield_default(self): + ty = self.Infer(""" + import dataclasses + @dataclasses.dataclass + class C: + _: dataclasses.KW_ONLY + x: int = 0 + y: str + """) + self.assertTypesMatchPytd(ty, """ + import dataclasses + @dataclasses.dataclass + class C: + x: int = ... + y: str + _: dataclasses.KW_ONLY + def __init__(self, *, x: int = ..., y: str) -> None: ... + """) + def test_star_import(self): with self.DepTree([("foo.pyi", """ import dataclasses diff --git a/pytype/tests/test_flax_overlay.py b/pytype/tests/test_flax_overlay.py index 40b57fff0..b80602680 100644 --- a/pytype/tests/test_flax_overlay.py +++ b/pytype/tests/test_flax_overlay.py @@ -266,6 +266,34 @@ def __init__( def replace(self: _TBaz, **kwargs) -> _TBaz: ... """) + @test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10") + def test_kwonly(self): + with test_utils.Tempdir() as d: + self._setup_linen_pyi(d) + ty = self.Infer(""" + import dataclasses + from flax import linen as nn + class C(nn.Module): + _: dataclasses.KW_ONLY + x: int = 0 + y: str + """, pythonpath=[d.path]) + self.assertTypesMatchPytd(ty, """ + import dataclasses + from flax import linen as nn + from typing import Any, TypeVar + + _TC = TypeVar('_TC', bound=C) + + @dataclasses.dataclass + class C(nn.module.Module): + x: int = ... + y: str + _: dataclasses.KW_ONLY + def __init__(self, *, x: int = ..., y: str, name: str = ..., parent: Any = ...) -> None: ... + def replace(self: _TC, **kwargs) -> _TC: ... + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_typed_dict.py b/pytype/tests/test_typed_dict.py index 5f89d6646..302b391d9 100644 --- a/pytype/tests/test_typed_dict.py +++ b/pytype/tests/test_typed_dict.py @@ -385,6 +385,14 @@ def f() -> TD: return __any_object__ """) + def test_duplicate_key(self): + self.CheckWithErrors(""" + from typing_extensions import TypedDict + class TD(TypedDict): # invalid-annotation + x: int + x: str + """) + class TypedDictFunctionalTest(test_base.BaseTest): """Tests for typing.TypedDict functional constructor.""" @@ -458,6 +466,16 @@ class X(TypedDict, total=False): name: str """) + def test_ambiguous_field_type(self): + self.CheckWithErrors(""" + from typing_extensions import TypedDict + if __random__: + v = str + else: + v = int + X = TypedDict('X', {'k': v}) # invalid-annotation + """) + _SINGLE = """ from typing import TypedDict diff --git a/pytype/tests/test_typing_self.py b/pytype/tests/test_typing_self.py index 61d5410a6..6f88ec37e 100644 --- a/pytype/tests/test_typing_self.py +++ b/pytype/tests/test_typing_self.py @@ -281,6 +281,17 @@ def f(self) -> Self: return self """) + def test_signature_compatibility(self): + self.Check(""" + from typing_extensions import Self + class Parent: + def add(self, other: Self) -> Self: + return self + class Child(Parent): + def add(self, other: Self) -> Self: + return self + """) + class SelfPyiTest(test_base.BaseTest): """Tests for typing.Self usage in type stubs."""