Skip to content

Commit

Permalink
Merge pull request #1600 from google/google_sync
Browse files Browse the repository at this point in the history
Add 'default' field to pytd.TypeParameter.
  • Loading branch information
rchen152 authored Mar 19, 2024
2 parents b899682 + bf32914 commit a4d7003
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 10 deletions.
8 changes: 5 additions & 3 deletions pytype/pyi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,13 @@ def add_type_variable(self, name, tvar):
raise _ParseError(f"{tvar.kind} name needs to be {tvar.name!r} "
f"(not {name!r})")
bound = tvar.bound
if isinstance(bound, str):
bound = pytd.NamedType(bound)
constraints = tuple(tvar.constraints) if tvar.constraints else ()
if isinstance(tvar.default, list):
default = tuple(tvar.default)
else:
default = tvar.default
self.type_params.append(pytd_type(
name=name, constraints=constraints, bound=bound))
name=name, constraints=constraints, bound=bound, default=default))

def add_import(self, from_package, import_list):
"""Add an import.
Expand Down
16 changes: 11 additions & 5 deletions pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
import sys
import tokenize
from typing import Any, List, Optional, Tuple, cast
from typing import Any, List, Optional, Tuple, Union, cast

from pytype.ast import debug
from pytype.pyi import conditions
Expand Down Expand Up @@ -61,6 +61,7 @@ class _TypeVariable:
name: str
bound: Optional[pytd.Type]
constraints: List[pytd.Type]
default: Optional[Union[pytd.Type, List[pytd.Type]]]

@classmethod
def from_call(cls, kind: str, node: astlib.Call):
Expand All @@ -72,18 +73,20 @@ def from_call(cls, kind: str, node: astlib.Call):
if not types.Pyval.is_str(name):
raise ParseError(f"Bad arguments to {kind}")
bound = None
# 'bound' is the only keyword argument we currently use.
default = None
# TODO(rechen): We should enforce the PEP 484 guideline that
# len(constraints) != 1. However, this guideline is currently violated
# in typeshed (see https://github.com/python/typeshed/pull/806).
kws = {x.arg for x in node.keywords}
extra = kws - {"bound", "covariant", "contravariant"}
extra = kws - {"bound", "covariant", "contravariant", "default"}
if extra:
raise ParseError(f"Unrecognized keyword(s): {', '.join(extra)}")
for kw in node.keywords:
if kw.arg == "bound":
bound = kw.value
return cls(kind, name.value, bound, constraints)
elif kw.arg == "default":
default = kw.value
return cls(kind, name.value, bound, constraints, default)

#------------------------------------------------------
# Main tree visitor and generator code
Expand Down Expand Up @@ -674,6 +677,8 @@ def _convert_typevar_args(self, node: astlib.Call):
for kw in node.keywords:
if kw.arg == "bound":
kw.value = self.annotation_visitor.visit(kw.value)
elif kw.arg == "default":
kw.value = self.annotation_visitor.visit(kw.value)

def _convert_typed_dict_args(self, node: astlib.Call):
for fields in node.args[1:]:
Expand All @@ -682,7 +687,8 @@ def _convert_typed_dict_args(self, node: astlib.Call):
def enter_Call(self, node):
node.func = self.annotation_visitor.visit(node.func)
func = node.func.name or ""
if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec")):
if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec",
"typing.TypeVarTuple")):
self._convert_typevar_args(node)
elif self.defs.matches_type(func, "typing.NamedTuple"):
self._convert_typing_namedtuple_args(node)
Expand Down
25 changes: 25 additions & 0 deletions pytype/pyi/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3526,5 +3526,30 @@ def f(x: Tuple[Any, ...]) -> Any: ...
""")


class TypeParameterDefaultTest(parser_test_base.ParserTestBase):

def test_typevar(self):
self.check("""
from typing_extensions import TypeVar
T = TypeVar('T', default=int)
""")

def test_paramspec(self):
self.check("""
from typing_extensions import ParamSpec
P = ParamSpec('P', default=[str, int])
""")

def test_typevartuple(self):
self.check("""
from typing_extensions import TypeVarTuple, Unpack
Ts = TypeVarTuple('Ts', default=Unpack[tuple[str, int]])
""", """
from typing_extensions import TypeVarTuple, TypeVarTuple as Ts, Unpack
""")


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions pytype/pytd/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ def _FormatTypeParams(self, type_params):
args += [self.Print(c) for c in t.constraints]
if t.bound:
args.append(f"bound={self.Print(t.bound)}")
if isinstance(t.default, tuple):
args.append(
f"default=[{', '.join(self.Print(d) for d in t.default)}]")
elif t.default:
args.append(f"default={self.Print(t.default)}")
if isinstance(t, pytd.ParamSpec):
typename = self._LookupTypingMember("ParamSpec")
else:
Expand Down
1 change: 1 addition & 0 deletions pytype/pytd/pytd.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def f(x: T) -> T
name: str
constraints: Tuple[TypeU, ...] = ()
bound: Optional[TypeU] = None
default: Optional[Union[TypeU, Tuple[TypeU, ...]]] = None
scope: Optional[str] = None

def __lt__(self, other):
Expand Down
3 changes: 1 addition & 2 deletions pytype/pytd/visitors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,7 @@ class A(Dict[T, T], Generic[T]): pass
""")
a = ast.Lookup("A")
self.assertEqual(
(pytd.TemplateItem(pytd.TypeParameter("T", (), None, "A")),),
a.template)
(pytd.TemplateItem(pytd.TypeParameter("T", scope="A")),), a.template)

def test_adjust_type_parameters_with_duplicates_in_generic(self):
src = textwrap.dedent("""
Expand Down

0 comments on commit a4d7003

Please sign in to comment.