Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit e6c4aca

Browse files
Getting local variable annotations of function and generics (#895)
1 parent 30e5395 commit e6c4aca

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import ast
2+
import inspect
3+
import textwrap
4+
from pathlib import Path
5+
6+
7+
def get_variable_annotations(func):
8+
"""Get local variable annotations from the function."""
9+
module_path = inspect.getfile(func)
10+
func_global_variables = func.__globals__
11+
func_global_variables.update(inspect.getclosurevars(func).nonlocals)
12+
module_name = (Path(f'{module_path}').stem)
13+
func_code = inspect.getsource(func)
14+
func_code_with_dedent = textwrap.dedent(func_code)
15+
func_tree = ast.parse(func_code_with_dedent)
16+
analyzer = Analyzer(module_name, func_global_variables)
17+
analyzer.visit(func_tree)
18+
return analyzer.locals_parameter
19+
20+
21+
class Analyzer(ast.NodeVisitor):
22+
def __init__(self, module_name, func_global_variables):
23+
self.locals_parameter = {}
24+
self.module_name = module_name
25+
self.global_parameter = func_global_variables
26+
27+
def visit_AnnAssign(self, node):
28+
target, annotation = node.target, node.annotation
29+
if isinstance(annotation, ast.Subscript): # containers and generics
30+
try:
31+
container_name = annotation.value.id
32+
module_container_name = ''
33+
except AttributeError: # typing.
34+
module_container_name = annotation.value.value.id + '.'
35+
container_name = annotation.value.attr
36+
if isinstance(annotation.slice.value, ast.Tuple):
37+
types_as_str = ','.join(elt.id for elt in annotation.slice.value.elts)
38+
exec_variables = f'{target.id} = [{module_container_name}{container_name}[{types_as_str}]]'
39+
else:
40+
exec_variables = f'{target.id} = [{module_container_name}{container_name}[{annotation.slice.value.id}]]'
41+
exec(exec_variables, self.global_parameter, self.locals_parameter)
42+
else: # not containers
43+
exec(f'{target.id} = [{annotation.id}]', self.global_parameter, self.locals_parameter)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from local_variable_type_annotations import get_variable_annotations
2+
import unittest
3+
from typing import Any, Union, List, Tuple, Dict, Iterable, Iterator, Generic, TypeVar
4+
import typing
5+
6+
T = TypeVar('T')
7+
S = TypeVar('S', int, str)
8+
G = Generic[T, S]
9+
10+
11+
class A():
12+
...
13+
14+
15+
class B():
16+
...
17+
18+
19+
class C():
20+
...
21+
22+
23+
class TestAst(unittest.TestCase):
24+
maxDiff = None
25+
26+
def test_get_variable_annotations_standard_types(self):
27+
def test_func():
28+
t = 1
29+
t_int: int
30+
t_str: str
31+
t_float: float
32+
t_bool: bool
33+
t_bytes: bytes
34+
t_list: list
35+
t_dict: dict
36+
t_tuple: tuple
37+
t_set: set
38+
result = get_variable_annotations(test_func)
39+
expected_result = {'t_int': [int], 't_str': [str], 't_float': [float], 't_bool': [bool], 't_bytes': [bytes],
40+
't_list': [list], 't_dict': [dict], 't_tuple': [tuple], 't_set': [set]}
41+
self.assertEqual(result, expected_result)
42+
43+
def test_get_variable_annotations_generic_types(self):
44+
def test_func():
45+
t_any: Any
46+
t_union: Union[str, bytes]
47+
t_optional: typing.Optional[float]
48+
t_list: List[int]
49+
t_tuple: Tuple[str]
50+
t_dict: Dict[str, str]
51+
t_iterable: Iterable[float]
52+
t_iterator: Iterator[int]
53+
t_generic: G
54+
t_typevar_t: T
55+
t_typevar_s: S
56+
t_list_typevar: List[T]
57+
t_tuple_typevar: Tuple[S]
58+
t_dict_typevar: Dict[T, S]
59+
result = get_variable_annotations(test_func)
60+
expected_result = {'t_any': [Any], 't_union': [Union[str, bytes]], 't_optional': [typing.Optional[float]],
61+
't_list': [List[int]], 't_tuple': [Tuple[str]], 't_dict': [Dict[str, str]],
62+
't_iterable': [Iterable[float]], 't_iterator': [Iterator[int]], 't_generic': [G],
63+
't_typevar_t': [T], 't_typevar_s': [S], 't_list_typevar': [List[T]],
64+
't_tuple_typevar': [Tuple[S]], 't_dict_typevar': [Dict[T, S]]}
65+
self.assertEqual(result, expected_result)
66+
67+
def test_get_variable_annotations_user_types(self):
68+
def test_func():
69+
t_class_a: A
70+
t_class_b: B
71+
t_class_c: C
72+
result = get_variable_annotations(test_func)
73+
expected_result = {'t_class_a': [A], 't_class_b': [B], 't_class_c': [C]}
74+
self.assertEqual(result, expected_result)
75+
76+
def test_get_variable_annotations_non_locals(self):
77+
def foo():
78+
Q = TypeVar('Q')
79+
80+
def bar():
81+
t_typevar: Q
82+
return bar
83+
test_func = foo()
84+
result = get_variable_annotations(test_func)
85+
expected_result = {'t_typevar': [Q]}
86+
self.assertEqual(result, expected_result)
87+
88+
89+
if __name__ == "__main__":
90+
unittest.main()

0 commit comments

Comments
 (0)