Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Experiment to use pytypes to add support for python type hints #69

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ python:
install:
- pip install coverage
- pip install --upgrade pytest pytest-benchmark
- pip install pytypes

script:
- |
Expand Down
6 changes: 4 additions & 2 deletions multipledispatch/conflict.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from .utils import _toposort, groupby
from pytypes import is_subtype


class AmbiguityWarning(Warning):
pass


def supercedes(a, b):
""" A is consistent and strictly more specific than B """
return len(a) == len(b) and all(map(issubclass, a, b))
return len(a) == len(b) and all(map(is_subtype, a, b))


def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
return (len(a) == len(b) and
all(issubclass(aa, bb) or issubclass(bb, aa)
all(is_subtype(aa, bb) or is_subtype(bb, aa)
for aa, bb in zip(a, b)))


Expand Down
20 changes: 13 additions & 7 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl

import itertools as itl
import pytypes
import typing

class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
Expand Down Expand Up @@ -111,7 +113,7 @@ def get_func_params(cls, func):

@classmethod
def get_func_annotations(cls, func):
""" get annotations of function positional paremeters
""" get annotations of function positional parameters
"""
params = cls.get_func_params(func)
if params:
Expand All @@ -135,13 +137,17 @@ def add(self, signature, func, on_ambiguity=ambiguity_warn):
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D.add((typing.Optional[str], ), lambda x: x)

>>> D(1, 2)
3
>>> D(1, 2.0)
>>> D('1', 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
NotImplementedError: Could not find signature for add: <str, float>
>>> D('s')
's'
>>> D(None)

When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
Expand All @@ -154,7 +160,7 @@ def add(self, signature, func, on_ambiguity=ambiguity_warn):
signature = annotations

# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
if any(isinstance(typ, tuple) or pytypes.is_Union(typ) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func, on_ambiguity)
return
Expand Down Expand Up @@ -182,7 +188,7 @@ def reorder(self, on_ambiguity=ambiguity_warn):
_unresolved_dispatchers.add(self)

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
types = tuple([pytypes.deep_type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError:
Expand Down Expand Up @@ -244,7 +250,7 @@ def dispatch(self, *types):
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
if len(signature) == n and all(map(pytypes.is_subtype, types, signature)):
result = self.funcs[signature]
yield result

Expand Down
16 changes: 16 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher
import typing


def test_function_annotation_register():
Expand All @@ -30,8 +31,23 @@ def inc(x: int):
def inc(x: float):
return x - 1

@dispatch()
def inc(x: typing.Optional[str]):
return x

@dispatch()
def inc(x: typing.List[int]):
return x[0] * 4

@dispatch()
def inc(x: typing.List[str]):
return x[0] + 'b'

assert inc(1) == 2
assert inc(1.0) == 0.0
assert inc('a') == 'a'
assert inc([8]) == 32
assert inc(['a']) == 'ab'


def test_function_annotation_dispatch_custom_namespace():
Expand Down
22 changes: 17 additions & 5 deletions multipledispatch/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@

import pytypes
import typing


def raises(err, lamda):
try:
lamda()
Expand All @@ -14,15 +19,22 @@ def expand_tuples(L):

>>> expand_tuples([1, 2])
[(1, 2)]

>>> expand_tuples([1, typing.Optional[str]]) #doctest: +ELLIPSIS
[(1, <... 'str'>), (1, <... 'NoneType'>)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
if pytypes.is_Union(L[0]):
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in pytypes.get_Union_params(L[0])]
elif not pytypes.is_of_type(L[0], tuple):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

somehow this doesn't actually get hit when L[0] is not a tuple:

(Pdb) p L[0]
<type 'numpy.dtype'>
(Pdb) p pytypes.is_of_type(L[0], tuple)
True

This breaks importing datashape.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The root cause of this is actually more concerning than this one bug. Why doesn't pytypes.is_of_type work correctly for type objects?

rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]


# Taken from theano/theano/gof/sched.py
Expand Down