Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 0be9331

Browse files
arguments_annotation_default_func_class (#894)
* add getting annotation for func and classes * fix pep8 remarks * from inspect import signature, getfullargspecfix pep8 remarks * another fix pep8 remarks * another fix pep8 remarks * another fix pep8 remarks * small fix * another small fixes * another small fixes * another small fixes * add unit tests and handler for union * fix comments * add function expend_annotations * add tests for function expend_annotations * fix comment * fix comments - improve product_annotations function * add some functions that extend annotations * add function and tests * fix problem with tests * fix comments * rename annotation to signature * fix comments * small changes * fix pep issues
2 parents e6c4aca + d91d980 commit 0be9331

File tree

3 files changed

+346
-0
lines changed

3 files changed

+346
-0
lines changed

numba_typing/tests/__init__.py

Whitespace-only changes.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import unittest
2+
import type_annotations
3+
from typing import Union, Dict, List, TypeVar
4+
from collections import namedtuple
5+
6+
7+
def check_equal(result, expected):
8+
if len(result) != len(expected):
9+
return False
10+
for sig in result:
11+
if sig not in expected:
12+
return False
13+
return True
14+
15+
16+
def check_equal_annotations(result, expected):
17+
18+
if len(result.parameters) != len(expected.parameters):
19+
return False
20+
21+
for sig in result.parameters:
22+
if sig not in expected.parameters:
23+
return False
24+
25+
if result.defaults != expected.defaults:
26+
return False
27+
28+
return True
29+
30+
31+
class TestTypeAnnotations(unittest.TestCase):
32+
33+
def test_get_func_annotations_exceptions(self):
34+
35+
def foo(a: int, b, c: str = "string"):
36+
pass
37+
with self.assertRaises(SyntaxError) as raises:
38+
type_annotations.get_func_annotations(foo)
39+
self.assertIn('No annotation for parameter b', str(raises.exception))
40+
41+
def test_get_cls_annotations(self):
42+
class TestClass(object):
43+
x: int = 3
44+
y: str = "string"
45+
46+
def __init__(self, x, y):
47+
self.x = x
48+
self.y = y
49+
50+
result = type_annotations.get_cls_annotations(TestClass)
51+
expected = ({'x': [int], 'y': [str]}, {})
52+
self.assertEqual(result, expected)
53+
54+
def test_get_func_annotations(self):
55+
56+
def func_one(a: int, b: Union[int, float], c: str):
57+
pass
58+
59+
def func_two(a: int = 2, b: str = "string", c: List[int] = [1, 2, 3]):
60+
pass
61+
62+
def func_three(a: Dict[int, str], b: str = "string", c: int = 1):
63+
pass
64+
65+
expected_results = {
66+
func_one: ({'a': [int], 'b': [int, float], 'c': [str]}, {}),
67+
func_two: ({'a': [int], 'b': [str], 'c': [List[int]]}, {'a': 2, 'b': 'string', 'c': [1, 2, 3]}),
68+
func_three: ({'a': [Dict[int, str]], 'b': [str], 'c': [int]}, {'b': 'string', 'c': 1}),
69+
}
70+
for f, expected in expected_results.items():
71+
with self.subTest(func=f.__name__):
72+
self.assertEqual(type_annotations.get_func_annotations(f), expected)
73+
74+
def test_convert_to_sig_list(self):
75+
T = TypeVar('T', int, str)
76+
S = TypeVar('S', float, str)
77+
annotations = [{'a': [int], 'b': [int, float], 'c': [S]},
78+
{'a': [int], 'b': [T], 'c': [S]},
79+
{'a': [int, str], 'b': [int, float], 'c': [S]}]
80+
81+
expected = [[{'a': int, 'b': int, 'c': S},
82+
{'a': int, 'b': float, 'c': S}],
83+
[{'a': int, 'b': T, 'c': S}],
84+
[{'a': int, 'b': int, 'c': S},
85+
{'a': int, 'b': float, 'c': S},
86+
{'a': str, 'b': int, 'c': S},
87+
{'a': str, 'b': float, 'c': S}]]
88+
89+
for i in range(len(annotations)):
90+
with self.subTest(annotations=i):
91+
self.assertEqual(type_annotations.convert_to_sig_list(annotations[i]), expected[i])
92+
93+
def test_get_typevars(self):
94+
T = TypeVar('T', int, str)
95+
S = TypeVar('S', float, str)
96+
types = [List[T], Dict[T, S], int, T, List[List[T]]]
97+
98+
expected = [{T}, {T, S}, set(), {T}, {T}]
99+
100+
for i in range(len(types)):
101+
with self.subTest(types=i):
102+
self.assertEqual(type_annotations.get_typevars(types[i]), expected[i])
103+
104+
def test_add_vals_to_signature(self):
105+
signature = [{'a': Dict[float, int], 'b': int},
106+
{'a': Dict[str, int], 'b': int},
107+
{'a': Dict[float, str], 'b': int},
108+
{'a': Dict[str, str], 'b': int}]
109+
vals = {'a': {'name': 3}, 'b': 3}
110+
111+
expected = type_annotations.Signature(parameters=[{'a': Dict[float, int], 'b': int},
112+
{'a': Dict[str, int], 'b': int},
113+
{'a': Dict[float, str], 'b': int},
114+
{'a': Dict[str, str], 'b': int}],
115+
defaults={'a': {'name': 3}, 'b': 3})
116+
117+
result = type_annotations.add_vals_to_signature(signature, vals)
118+
119+
self.assertEqual(result, expected)
120+
121+
def test_replace_typevar(self):
122+
T = TypeVar('T', int, str)
123+
S = TypeVar('S', float, str)
124+
125+
types = [List[List[T]], Dict[T, S], T]
126+
expected = [List[List[int]], Dict[int, S], int]
127+
128+
for i in range(len(types)):
129+
with self.subTest(types=i):
130+
self.assertEqual(type_annotations.replace_typevar(types[i], T, int), expected[i])
131+
132+
def test_get_internal_typevars(self):
133+
134+
T = TypeVar('T', int, str)
135+
S = TypeVar('S', float, bool)
136+
signature = {'a': T, 'b': Dict[T, S]}
137+
expected = [{'a': int, 'b': Dict[int, float]},
138+
{'a': int, 'b': Dict[int, bool]},
139+
{'a': str, 'b': Dict[str, float]},
140+
{'a': str, 'b': Dict[str, bool]}]
141+
142+
result = type_annotations.get_internal_typevars(signature)
143+
144+
self.assertTrue(check_equal(result, expected))
145+
146+
def test_update_sig(self):
147+
T = TypeVar('T', int, str)
148+
S = TypeVar('S', float, bool)
149+
150+
sig = {'a': T, 'b': Dict[T, S]}
151+
expected = [{'a': T, 'b': Dict[T, float]},
152+
{'a': T, 'b': Dict[T, bool]}]
153+
result = type_annotations.update_sig(sig, S)
154+
155+
self.assertEqual(result, expected)
156+
157+
def test_expand_typevars(self):
158+
T = TypeVar('T', int, str)
159+
S = TypeVar('S', float, bool)
160+
161+
sig = {'a': T, 'b': Dict[T, S], 'c': int}
162+
unique_typevars = {T, S}
163+
expected = [{'a': int, 'b': Dict[int, float], 'c': int},
164+
{'a': int, 'b': Dict[int, bool], 'c': int},
165+
{'a': str, 'b': Dict[str, float], 'c': int},
166+
{'a': str, 'b': Dict[str, bool], 'c': int}]
167+
168+
result = type_annotations.expand_typevars(sig, unique_typevars)
169+
170+
self.assertTrue(check_equal(result, expected))
171+
172+
def test_product_annotations(self):
173+
174+
T = TypeVar('T', int, str)
175+
S = TypeVar('S', float, bool)
176+
177+
annotations = ({'a': [T], 'b': [Dict[T, S]],
178+
'c': [T, bool], 'd': [int]}, {'d': 3})
179+
180+
expected = type_annotations.Signature(parameters=[{'a': int, 'b': Dict[int, float], 'c': int, 'd': int},
181+
{'a': str, 'b': Dict[str, float], 'c': str, 'd': int},
182+
{'a': int, 'b': Dict[int, bool], 'c': int, 'd': int},
183+
{'a': str, 'b': Dict[str, bool], 'c': str, 'd': int},
184+
{'a': int, 'b': Dict[int, float], 'c': bool, 'd': int},
185+
{'a': str, 'b': Dict[str, float], 'c': bool, 'd': int},
186+
{'a': int, 'b': Dict[int, bool], 'c': bool, 'd': int},
187+
{'a': str, 'b': Dict[str, bool], 'c': bool, 'd': int}],
188+
defaults={'d': 3})
189+
190+
result = type_annotations.product_annotations(annotations)
191+
192+
self.assertTrue(check_equal_annotations(result, expected))
193+
194+
195+
if __name__ == '__main__':
196+
unittest.main()

numba_typing/type_annotations.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from inspect import signature
2+
from typing import get_type_hints, Union, TypeVar, _GenericAlias
3+
from itertools import product
4+
from copy import deepcopy
5+
from collections import namedtuple
6+
7+
Signature = namedtuple('Signature', ['parameters', 'defaults'])
8+
9+
10+
def get_func_annotations(func):
11+
"""Get annotations and default values of the fuction parameters."""
12+
sig = signature(func)
13+
annotations = {}
14+
defaults = {}
15+
16+
for name, param in sig.parameters.items():
17+
if param.annotation == sig.empty:
18+
raise SyntaxError(f'No annotation for parameter {name}')
19+
20+
annotations[name] = get_annotation_types(param.annotation)
21+
if param.default != sig.empty:
22+
defaults[name] = param.default
23+
24+
return annotations, defaults
25+
26+
27+
def get_cls_annotations(cls):
28+
"""Get annotations of class attributes."""
29+
annotations = get_type_hints(cls)
30+
for x in annotations:
31+
annotations[x] = get_annotation_types(annotations[x])
32+
return annotations, {}
33+
34+
35+
def get_annotation_types(annotation):
36+
"""Get types of passed annotation."""
37+
try:
38+
if annotation.__origin__ is Union:
39+
return list(annotation.__args__)
40+
except AttributeError:
41+
pass
42+
43+
return [annotation, ]
44+
45+
46+
def product_annotations(annotations):
47+
'''Get all variants of annotations.'''
48+
types, vals = annotations
49+
list_of_sig = convert_to_sig_list(types)
50+
signature = []
51+
52+
for sig in list_of_sig:
53+
signature.extend(get_internal_typevars(sig))
54+
55+
return add_vals_to_signature(signature, vals)
56+
57+
58+
def add_vals_to_signature(sign, vals):
59+
'''Add default values ​​to all signatures'''
60+
signature = Signature(sign, vals)
61+
return signature
62+
63+
64+
def convert_to_sig_list(types):
65+
'''Expands all Unions'''
66+
types_product = list(product(*types.values()))
67+
names = list(types)
68+
result = []
69+
70+
for sig in types_product:
71+
sig_result = {}
72+
for i in range(len(sig)):
73+
sig_result[names[i]] = sig[i]
74+
result.append(sig_result)
75+
76+
return result
77+
78+
79+
def get_internal_typevars(sig):
80+
'''Get unique typevars in signature'''
81+
unique_typevars = set()
82+
for typ in sig.values():
83+
unique_typevars.update(get_typevars(typ))
84+
85+
if len(unique_typevars) == 0:
86+
return [sig]
87+
88+
return expand_typevars(sig, unique_typevars)
89+
90+
91+
def get_typevars(type):
92+
'''Get unique typevars in type (container)'''
93+
if isinstance(type, TypeVar) and type.__constraints__:
94+
return {type, }
95+
elif isinstance(type, _GenericAlias):
96+
result = set()
97+
for arg in type.__args__:
98+
result.update(get_typevars(arg))
99+
return result
100+
101+
return set()
102+
103+
104+
def expand_typevars(sig, unique_typevars):
105+
'''Exstend all Typevars in signature'''
106+
result = [sig]
107+
108+
for typevar in unique_typevars:
109+
temp_result = []
110+
for temp_sig in result:
111+
temp_result.extend(update_sig(temp_sig, typevar))
112+
result = temp_result
113+
114+
return result
115+
116+
117+
def update_sig(temp_sig, typevar):
118+
'''Expand one typevar'''
119+
result = []
120+
for constr_type in typevar.__constraints__:
121+
sig = {}
122+
for name, typ in temp_sig.items():
123+
if typevar in get_typevars(typ):
124+
sig[name] = replace_typevar(typ, typevar, constr_type)
125+
else:
126+
sig[name] = typ
127+
128+
result.append(sig)
129+
130+
return result
131+
132+
133+
def replace_typevar(typ, typevar, final_typ):
134+
'''Replace typevar with type in container
135+
For example:
136+
# typ = Dict[T, V]
137+
# typevar = T(int, str)
138+
# final_typ = int
139+
'''
140+
141+
if typ == typevar:
142+
return final_typ
143+
elif isinstance(typ, _GenericAlias):
144+
result = list()
145+
for arg in typ.__args__:
146+
result.append(replace_typevar(arg, typevar, final_typ))
147+
result_type = typ.copy_with(tuple(result))
148+
return result_type
149+
150+
return typ

0 commit comments

Comments
 (0)