diff --git a/pytype/pyi/definitions.py b/pytype/pyi/definitions.py index d6922a3cc..7c5a9d83b 100644 --- a/pytype/pyi/definitions.py +++ b/pytype/pyi/definitions.py @@ -468,11 +468,13 @@ def add_type_variable(self, name, tvar): raise _ParseError(f"{tvar.kind} name needs to be {tvar.name!r} " f"(not {name!r})") bound = tvar.bound - if isinstance(bound, str): - bound = pytd.NamedType(bound) constraints = tuple(tvar.constraints) if tvar.constraints else () + if isinstance(tvar.default, list): + default = tuple(tvar.default) + else: + default = tvar.default self.type_params.append(pytd_type( - name=name, constraints=constraints, bound=bound)) + name=name, constraints=constraints, bound=bound, default=default)) def add_import(self, from_package, import_list): """Add an import. diff --git a/pytype/pyi/parser.py b/pytype/pyi/parser.py index 89ff9a8ed..c1af66698 100644 --- a/pytype/pyi/parser.py +++ b/pytype/pyi/parser.py @@ -8,7 +8,7 @@ import re import sys import tokenize -from typing import Any, List, Optional, Tuple, cast +from typing import Any, List, Optional, Tuple, Union, cast from pytype.ast import debug from pytype.pyi import conditions @@ -61,6 +61,7 @@ class _TypeVariable: name: str bound: Optional[pytd.Type] constraints: List[pytd.Type] + default: Optional[Union[pytd.Type, List[pytd.Type]]] @classmethod def from_call(cls, kind: str, node: astlib.Call): @@ -72,18 +73,20 @@ def from_call(cls, kind: str, node: astlib.Call): if not types.Pyval.is_str(name): raise ParseError(f"Bad arguments to {kind}") bound = None - # 'bound' is the only keyword argument we currently use. + default = None # TODO(rechen): We should enforce the PEP 484 guideline that # len(constraints) != 1. However, this guideline is currently violated # in typeshed (see https://github.com/python/typeshed/pull/806). kws = {x.arg for x in node.keywords} - extra = kws - {"bound", "covariant", "contravariant"} + extra = kws - {"bound", "covariant", "contravariant", "default"} if extra: raise ParseError(f"Unrecognized keyword(s): {', '.join(extra)}") for kw in node.keywords: if kw.arg == "bound": bound = kw.value - return cls(kind, name.value, bound, constraints) + elif kw.arg == "default": + default = kw.value + return cls(kind, name.value, bound, constraints, default) #------------------------------------------------------ # Main tree visitor and generator code @@ -674,6 +677,8 @@ def _convert_typevar_args(self, node: astlib.Call): for kw in node.keywords: if kw.arg == "bound": kw.value = self.annotation_visitor.visit(kw.value) + elif kw.arg == "default": + kw.value = self.annotation_visitor.visit(kw.value) def _convert_typed_dict_args(self, node: astlib.Call): for fields in node.args[1:]: @@ -682,7 +687,8 @@ def _convert_typed_dict_args(self, node: astlib.Call): def enter_Call(self, node): node.func = self.annotation_visitor.visit(node.func) func = node.func.name or "" - if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec")): + if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec", + "typing.TypeVarTuple")): self._convert_typevar_args(node) elif self.defs.matches_type(func, "typing.NamedTuple"): self._convert_typing_namedtuple_args(node) diff --git a/pytype/pyi/parser_test.py b/pytype/pyi/parser_test.py index 1cd278aad..bd3aa2056 100644 --- a/pytype/pyi/parser_test.py +++ b/pytype/pyi/parser_test.py @@ -3526,5 +3526,30 @@ def f(x: Tuple[Any, ...]) -> Any: ... """) +class TypeParameterDefaultTest(parser_test_base.ParserTestBase): + + def test_typevar(self): + self.check(""" + from typing_extensions import TypeVar + + T = TypeVar('T', default=int) + """) + + def test_paramspec(self): + self.check(""" + from typing_extensions import ParamSpec + + P = ParamSpec('P', default=[str, int]) + """) + + def test_typevartuple(self): + self.check(""" + from typing_extensions import TypeVarTuple, Unpack + Ts = TypeVarTuple('Ts', default=Unpack[tuple[str, int]]) + """, """ + from typing_extensions import TypeVarTuple, TypeVarTuple as Ts, Unpack + """) + + if __name__ == "__main__": unittest.main() diff --git a/pytype/pytd/printer.py b/pytype/pytd/printer.py index 9ae082226..226255c74 100644 --- a/pytype/pytd/printer.py +++ b/pytype/pytd/printer.py @@ -209,6 +209,11 @@ def _FormatTypeParams(self, type_params): args += [self.Print(c) for c in t.constraints] if t.bound: args.append(f"bound={self.Print(t.bound)}") + if isinstance(t.default, tuple): + args.append( + f"default=[{', '.join(self.Print(d) for d in t.default)}]") + elif t.default: + args.append(f"default={self.Print(t.default)}") if isinstance(t, pytd.ParamSpec): typename = self._LookupTypingMember("ParamSpec") else: diff --git a/pytype/pytd/pytd.py b/pytype/pytd/pytd.py index 56f505e77..3a6b4a33a 100644 --- a/pytype/pytd/pytd.py +++ b/pytype/pytd/pytd.py @@ -344,6 +344,7 @@ def f(x: T) -> T name: str constraints: Tuple[TypeU, ...] = () bound: Optional[TypeU] = None + default: Optional[Union[TypeU, Tuple[TypeU, ...]]] = None scope: Optional[str] = None def __lt__(self, other): diff --git a/pytype/pytd/visitors_test.py b/pytype/pytd/visitors_test.py index 28bcd4846..576ca2791 100644 --- a/pytype/pytd/visitors_test.py +++ b/pytype/pytd/visitors_test.py @@ -552,8 +552,7 @@ class A(Dict[T, T], Generic[T]): pass """) a = ast.Lookup("A") self.assertEqual( - (pytd.TemplateItem(pytd.TypeParameter("T", (), None, "A")),), - a.template) + (pytd.TemplateItem(pytd.TypeParameter("T", scope="A")),), a.template) def test_adjust_type_parameters_with_duplicates_in_generic(self): src = textwrap.dedent("""