-
Notifications
You must be signed in to change notification settings - Fork 61
automatic generation of type checks for overload #911
base: numba_typing
Are you sure you want to change the base?
automatic generation of type checks for overload #911
Conversation
Hello @vlad-perevezentsev! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found: There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-09-29 14:36:35 UTC |
numba_typing/overload_list.py
Outdated
return self._types_dict[p_type](self, p_type, n_type) | ||
return self._types_dict[p_type](n_type) | ||
except KeyError: | ||
print((f'A check for the {p_type} was not found. {n_type}')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So we will (implicitly) return None
from match
in this case, is this intended behavior? Please add explicit return or throw
numba_typing/overload_list.py
Outdated
|
||
def choose_func_by_sig(sig_list, values_dict, defaults_dict={}): | ||
checker = TypeChecker() | ||
for sig in sig_list: # sig = (Signature,func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of accessing sig[0]
and sig[1]
each time you can write for sig, func in sig_list:
numba_typing/overload_list.py
Outdated
import numpy | ||
import numba | ||
from numba import types | ||
from numba import typeof | ||
from numba.extending import overload | ||
from type_annotations import product_annotations, get_func_annotations | ||
from numba import njit | ||
import typing | ||
from numba import NumbaDeprecationWarning, NumbaPendingDeprecationWarning | ||
import warnings | ||
from numba.typed import List, Dict | ||
from inspect import getfullargspec |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some unused imports here, like numpy
, typeof
and njit
, please clean them up.
numba_typing/overload_list.py
Outdated
warnings.simplefilter('ignore', category=NumbaDeprecationWarning) | ||
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed?
numba_typing/overload_list.py
Outdated
|
||
def check_tuple_type(self, p_type, n_type): | ||
res = False | ||
if isinstance(n_type, types.Tuple): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of using n_type.key
you could use n_type.types
which is defined for both Tuple
and UniTuple
, so you can unite both branches:
Something like:
if not isinstance(n_type, types.Tuple, types.UniTuple):
return False
for p_val, n_val in zip(p_type.__args__, n_type.types):
if not self.match(p_val, n_val):
return False
return True
And btw you need to check that size of p_type.__args__
and n_type.types
are the same.
And I believe you don't have tests for the case when they are different.
numba_typing/overload_list.py
Outdated
return self._types_dict[p_type](self, p_type, n_type) | ||
return self._types_dict[p_type](n_type) | ||
except KeyError: | ||
print((f'A check for the {p_type} was not found.')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should rise an exception
numba_typing/overload_list.py
Outdated
|
||
def match(self, p_type, n_type): | ||
try: | ||
if p_type == typing.Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do it like this:
if p_type == typing.Any:
return True
if self._is_generic(p_type):
origin_type = self._get_origin(p_type)
if origin_type == typing.Generic:
return self.match_generic(p_type, n_type)
return self._types_dict[origin_type](self, p_type, n_type)
if isinstance(p_type, typing.TypeVar):
return self.match_typevar(p_type, n_type)
if p_type in (list, tuple):
return self._types_dict[p_type](self, p_type, n_type)
return self._types_dict[p_type](n_type)
numba_typing/overload_list.py
Outdated
return None | ||
|
||
def match_typevar(self, p_type, n_type): | ||
if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need condition n_type not in self._typevars_dict.values()
?
if not self._typevars_dict.get(p_type) and n_type not in self._typevars_dict.values(): | ||
self._typevars_dict[p_type] = n_type | ||
return True | ||
return self._typevars_dict.get(p_type) == n_type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe it should be self.match
. E.g. list
and 'types.List' are synonyms but will fail equality check ('list != types.List').
And I'm assuming you don't have such tests?
numba_typing/overload_list.py
Outdated
def match_generic(self, p_type, n_type): | ||
res = True | ||
for arg in p_type.__args__: | ||
res = res and self.match(arg, n_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's doesn't feel right. Do we have any test for this case?
numba_typing/overload_list.py
Outdated
|
||
if sig.defaults.get(name, False): | ||
full_match = full_match and sig.defaults[name] == typ.literal_value | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this else
?
numba_typing/overload_list.py
Outdated
|
||
|
||
def check_list_type(self, p_type, n_type): | ||
res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
res = isinstance(n_type, (types.List, types.ListType))
numba_typing/overload_list.py
Outdated
|
||
def check_list_type(self, p_type, n_type): | ||
res = isinstance(n_type, types.List) or isinstance(n_type, types.ListType) | ||
if isinstance(p_type, type): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isinstance(p_type, (list, typing.List))
?
numba_typing/overload_list.py
Outdated
elif isinstance(p_type, typing.TypeVar): | ||
return self.match_typevar(p_type, n_type) | ||
else: | ||
if p_type in (list, tuple): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't dict be here too?
numba_typing/overload_list.py
Outdated
|
||
class TypeChecker: | ||
|
||
_types_dict = {int: check_int_type, float: check_float_type, bool: check_bool_type, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer this checks to be added using add_type_check
function
This code actually modifies original function globals. This shouldn't happen, please fix. |
numba_typing/overload_list.py
Outdated
TypeChecker.add_type_check(dict, check_dict_type) | ||
|
||
|
||
def choose_func_by_sig(sig_list, values_dict, defaults_dict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do it this way:
def choose_func_by_sig(sig_list, values_dict, defaults_dict):
def check_signature(sig_params, types_dict):
checker = TypeChecker()
for name, typ in types_dict.items(): # name,type = 'a',int64
if isinstance(typ, types.Literal):
typ = typ.literal_type
if not checker.match(sig_params[name], typ):
return False
return True
for sig, func in sig_list: # sig = (Signature,func)
for param in sig.parameters: # param = {'a':int,'b':int}
if check_signature(param, values_dict):
return func
return None
numba_typing/test_overload_list.py
Outdated
return typ.__name__ | ||
return typ | ||
|
||
value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please use f'{key}'
instead. Here and everywhere
numba_typing/test_overload_list.py
Outdated
D_list['qwe'] = List([3, 4, 5]) | ||
str_1 = 'qwe' | ||
str_2 = 'qaz' | ||
test_cases = [('int', [1, 2], {'a': int, 'b': int}), ('float', [1.0, 2.0], {'a': float, 'b': float}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
{'a': int, 'b': int}
why do you need two parameters of the same type?
What are you actually testing?
numba_typing/test_overload_list.py
Outdated
str_2 = 'qaz' | ||
test_cases = [('int', [1, 2], {'a': int, 'b': int}), ('float', [1.0, 2.0], {'a': float, 'b': float}), | ||
('bool', [True, True], {'a': bool, 'b': bool}), ('str', ['str_1', 'str_2'], {'a': str, 'b': str}), | ||
('list', [[1, 2], [3, 4]], {'a': typing.List[int], 'b':list}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need second list parameter?
numba_typing/test_overload_list.py
Outdated
('tuple', [(1, 2.0), ('3', False)], {'a': typing.Tuple[int, float], 'b':tuple}), | ||
('dict', ['D', 'D_1'], {'a': typing.Dict[str, int], 'b': typing.Dict[int, bool]}), | ||
('union_1', [1, False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), | ||
('union_2', ['str_1', False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And why you are not testing second Union parameter?
numba_typing/test_overload_list.py
Outdated
('TypeVar_ListT_DictK_ListT', ['L', 'D_list'], {'a': 'typing.List[T]', | ||
'b': 'typing.Dict[K, typing.List[T]]'})] | ||
|
||
test_cases_default = [('int_defaults', [1], {'a': int}, {'b': 0}), ('float_defaults', [1.0], {'a': float}, {'b': 0.0}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And what about both - annotation and default value?
numba_typing/test_overload_list.py
Outdated
('TypeVar_ListT_T', ['L', 5], {'a': 'typing.List[T]', 'b': 'T'}), | ||
('TypeVar_ListT_DictKT', ['L', 'D'], {'a': 'typing.List[T]', 'b': 'typing.Dict[K, T]'}), | ||
('TypeVar_ListT_DictK_ListT', ['L', 'D_list'], {'a': 'typing.List[T]', | ||
'b': 'typing.Dict[K, typing.List[T]]'})] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are where any tests for TypeVar with specified types restriction?
numba_typing/test_overload_list.py
Outdated
('union_2', ['str_1', False], {'a': typing.Union[int, str], 'b': typing.Union[float, bool]}), | ||
('nested_list', ['L_int', 'L_float'], {'a': typing.List[typing.List[int]], | ||
'b': typing.List[typing.List[typing.List[float]]]}), | ||
('TypeVar_TT', ['L_f', [3.0, 4.0]], {'a': 'T', 'b': 'T'}), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any negative tests for such case?
numba_typing/test_overload_list.py
Outdated
('tuple_defaults', [(1, 2)], {'a': tuple}, {'b': (0, 0)})] | ||
|
||
|
||
for name, val, annotation in test_cases: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer you to group all cases into 4-5 tests with subtests. Something like this:
def test_common_types():
test_cases = [({'a': 0}, {'a': int}),
({'a': 0.}, {'a': float}),
...]
for case in test_cases:
with self.Subtest(case=case):
run_test(case)
numba_typing/overload_list.py
Outdated
result = choose_func_by_sig(sig_list, values_dict) | ||
|
||
if result is None: | ||
raise TypeError(f'Unsupported types a={a}, b={b}') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks both a
and b
are undefined.
numba_typing/overload_list.py
Outdated
|
||
for sig, _ in list_signature: | ||
for param in sig.parameters: | ||
if len(param) != len(values_dict.items()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to call method items()
here.
numba_typing/overload_list.py
Outdated
for name, val in defaults_dict.items(): | ||
if sig_def.get(name) is None: | ||
raise AttributeError(f'{name} does not match the signature of the function passed to overload_list') | ||
if not sig_def[name] == val: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This condition looks pretty strange. Maybe if sig_def[name] != val
?
No description provided.