Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Sep 14, 2024
1 parent b27b64c commit 62f8c15
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
38 changes: 27 additions & 11 deletions src/basedtyping/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def implicit_tuple(self, *, value=True) -> typing.Iterator[None]:

def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
node_type = self.eval_type(node.value)
if self.eval_type(node.value) is typing_extensions.Literal:
return node
if node_type is typing_extensions.Annotated:
slice_ = node.slice
if isinstance(slice_, ast.Tuple):
Expand Down Expand Up @@ -148,6 +150,13 @@ def visit_Name(self, node: ast.Name) -> ast.AST:

def visit_Constant(self, node: ast.Constant) -> ast.AST:
value = cast(object, node.value)
if not self.string_literals and isinstance(value, str):
result = self._transform(
basedtyping.ForwardRef(value),
)
if isinstance(result, ast.Expression):
return result.body
return result
if isinstance(value, int) or (self.string_literals and isinstance(value, str)):
return self._literal(node)
return node
Expand Down Expand Up @@ -196,6 +205,23 @@ def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
)
return node

def _transform(
self,
value: typing.ForwardRef,
) -> ast.AST:
tree: ast.AST
try:
tree = ast.parse(value.__forward_arg__, mode="eval")
except SyntaxError:
arg = value.__forward_arg__.lstrip()
if arg.startswith(("def ", "def(")):
arg = arg[3:].lstrip()
tree = ast.parse(arg, mode="func_type")

tree = self.visit(tree)
assert isinstance(tree, (ast.FunctionType, ast.Expression, ast.expr))
return tree


def _eval_direct(
value: object,
Expand All @@ -218,16 +244,6 @@ def eval_type_based(
"""
if not isinstance(value, typing.ForwardRef):
return value
tree: ast.AST
try:
tree = ast.parse(value.__forward_arg__, mode="eval")
except SyntaxError:
arg = value.__forward_arg__
if arg.startswith("def"):
arg = arg[3:]
tree = ast.parse(arg.lstrip(), mode="func_type")

transformer = CringeTransformer(globalns, localns, string_literals=string_literals)
tree = transformer.visit(tree)
assert isinstance(tree, (ast.FunctionType, ast.Expression, ast.expr))
tree = transformer._transform(value)
return transformer.eval_type(tree, original_ref=value)
20 changes: 15 additions & 5 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ def test_literal_nested():
validate("List[(1, 2),]", List[Tuple[Literal[1], Literal[2]]])


def test_literal_str():
validate("'int'", int)
validate("Literal['int']", Literal["int"])
validate("'int'", Literal["int"], string_literals=True)
validate("Literal['int']", Literal["int"], string_literals=True)
def test_literal_str_forwardref():
validate("'1'", Literal[1])
validate("Literal['1']", Literal["1"])


def test_literal_str_literal():
validate("'1'", Literal["1"], string_literals=True)
validate("Literal['1']", Literal["1"], string_literals=True)


class E(Enum):
Expand Down Expand Up @@ -88,6 +91,13 @@ def test_function():
validate("FunctionType[[str], int]", Callable[[str], int])


def_ = int


def test_adversarial_function():
validate("def_ | '() -> int'", Union[def_, Callable[[], int]])


def test_functiontype():
validate("FunctionType[[str], int]", Callable[[str], int])

Expand Down

0 comments on commit 62f8c15

Please sign in to comment.