From 62f8c15688799d3df2cfb7cb48ccf3d5dcd89cf3 Mon Sep 17 00:00:00 2001 From: KotlinIsland Date: Sun, 15 Sep 2024 00:08:27 +1000 Subject: [PATCH] wip --- src/basedtyping/transformer.py | 38 ++++++++++++++++++++++++---------- tests/test_transformer.py | 20 +++++++++++++----- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/basedtyping/transformer.py b/src/basedtyping/transformer.py index 39e4c07..cae9edb 100644 --- a/src/basedtyping/transformer.py +++ b/src/basedtyping/transformer.py @@ -109,6 +109,8 @@ def implicit_tuple(self, *, value=True) -> typing.Iterator[None]: def visit_Subscript(self, node: ast.Subscript) -> ast.AST: node_type = self.eval_type(node.value) + if self.eval_type(node.value) is typing_extensions.Literal: + return node if node_type is typing_extensions.Annotated: slice_ = node.slice if isinstance(slice_, ast.Tuple): @@ -148,6 +150,13 @@ def visit_Name(self, node: ast.Name) -> ast.AST: def visit_Constant(self, node: ast.Constant) -> ast.AST: value = cast(object, node.value) + if not self.string_literals and isinstance(value, str): + result = self._transform( + basedtyping.ForwardRef(value), + ) + if isinstance(result, ast.Expression): + return result.body + return result if isinstance(value, int) or (self.string_literals and isinstance(value, str)): return self._literal(node) return node @@ -196,6 +205,23 @@ def visit_BinOp(self, node: ast.BinOp) -> ast.AST: ) return node + def _transform( + self, + value: typing.ForwardRef, + ) -> ast.AST: + tree: ast.AST + try: + tree = ast.parse(value.__forward_arg__, mode="eval") + except SyntaxError: + arg = value.__forward_arg__.lstrip() + if arg.startswith(("def ", "def(")): + arg = arg[3:].lstrip() + tree = ast.parse(arg, mode="func_type") + + tree = self.visit(tree) + assert isinstance(tree, (ast.FunctionType, ast.Expression, ast.expr)) + return tree + def _eval_direct( value: object, @@ -218,16 +244,6 @@ def eval_type_based( """ if not isinstance(value, typing.ForwardRef): return value - tree: ast.AST - try: - tree = ast.parse(value.__forward_arg__, mode="eval") - except SyntaxError: - arg = value.__forward_arg__ - if arg.startswith("def"): - arg = arg[3:] - tree = ast.parse(arg.lstrip(), mode="func_type") - transformer = CringeTransformer(globalns, localns, string_literals=string_literals) - tree = transformer.visit(tree) - assert isinstance(tree, (ast.FunctionType, ast.Expression, ast.expr)) + tree = transformer._transform(value) return transformer.eval_type(tree, original_ref=value) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 3a38e31..fea7f5e 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -42,11 +42,14 @@ def test_literal_nested(): validate("List[(1, 2),]", List[Tuple[Literal[1], Literal[2]]]) -def test_literal_str(): - validate("'int'", int) - validate("Literal['int']", Literal["int"]) - validate("'int'", Literal["int"], string_literals=True) - validate("Literal['int']", Literal["int"], string_literals=True) +def test_literal_str_forwardref(): + validate("'1'", Literal[1]) + validate("Literal['1']", Literal["1"]) + + +def test_literal_str_literal(): + validate("'1'", Literal["1"], string_literals=True) + validate("Literal['1']", Literal["1"], string_literals=True) class E(Enum): @@ -88,6 +91,13 @@ def test_function(): validate("FunctionType[[str], int]", Callable[[str], int]) +def_ = int + + +def test_adversarial_function(): + validate("def_ | '() -> int'", Union[def_, Callable[[], int]]) + + def test_functiontype(): validate("FunctionType[[str], int]", Callable[[str], int])