Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy dispatch #100

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/source/resolution.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ For example, here's a function that takes a ``float`` followed by any number
>>> f(2.0, '4', 6, 8)
20.0

Lazy Dispatch
-------------

You may need to refer to your own class while defining it. Just use its name as
a string and ``multipledispatch`` will resolve a name to a class during runtime

.. code::

class MyInteger(int):
@dispatch('MyInteger')
def add(self, x):
return self + x

Ambiguities
-----------

Expand Down
5 changes: 5 additions & 0 deletions multipledispatch/conflict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

from .utils import _toposort, groupby
from .variadic import isvariadic

Expand All @@ -8,6 +10,9 @@ class AmbiguityWarning(Warning):

def supercedes(a, b):
""" A is consistent and strictly more specific than B """
if any(isinstance(x, str) for x in itertools.chain(a, b)):
# skip due to lazy types
return False
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
Expand Down
48 changes: 42 additions & 6 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ class Dispatcher(object):
>>> f(3.0)
2.0
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc', \
'_lazy'

def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc

self._cache = {}
self._lazy = False

def register(self, *types, **kwargs):
""" register dispatcher with new implementation
Expand Down Expand Up @@ -214,9 +216,8 @@ def add(self, signature, func):
return

new_signature = []

for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
if not isinstance(typ, (type, list, str)):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError("Tried to dispatch on non-type: %s\n"
Expand All @@ -237,9 +238,12 @@ def add(self, signature, func):
'To use a variadic union type place the desired types '
'inside of a tuple, e.g., [(int, str)]'
)
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
typ = Variadic[typ[0]]

if isinstance(typ, str):
self._lazy = True

new_signature.append(typ)

self.funcs[tuple(new_signature)] = func
self._cache.clear()
Expand All @@ -264,6 +268,9 @@ def reorder(self, on_ambiguity=ambiguity_warn):
return od

def __call__(self, *args, **kwargs):
if self._lazy:
self._unlazy()

types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
Expand Down Expand Up @@ -359,6 +366,7 @@ def __setstate__(self, d):
self.funcs = d['funcs']
self._ordering = ordering(self.funcs)
self._cache = dict()
self._lazy = any(isinstance(t, str) for t in itl.chain(*d['funcs']))

@property
def __doc__(self):
Expand Down Expand Up @@ -400,6 +408,31 @@ def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))

def _unlazy(self):
funcs = {}
for signature, func in self.funcs.items():
new_signature = []
for typ in signature:
if isinstance(typ, str):
for frame_info in inspect.stack():
frame = frame_info[0]
scope = dict(frame.f_globals)
scope.update(frame.f_locals)
if typ in scope:
typ = scope[typ]
break
else:
raise NameError("name '%s' is not defined" % typ)
new_signature.append(typ)

new_signature = tuple(new_signature)
funcs[new_signature] = func

self.funcs = funcs
self.reorder()

self._lazy = False


def source(func):
s = 'File: %s\n\n' % inspect.getsourcefile(func)
Expand Down Expand Up @@ -427,6 +460,9 @@ def __get__(self, instance, owner):
return self

def __call__(self, *args, **kwargs):
if self._lazy:
self._unlazy()

types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
Expand Down
72 changes: 72 additions & 0 deletions multipledispatch/tests/test_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import warnings

from multipledispatch import dispatch
from multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
MethodDispatcher)
from multipledispatch.conflict import ambiguities
Expand Down Expand Up @@ -421,3 +422,74 @@ def _3(*objects):
assert f('a', ['a']) == 2
assert f(1) == 3
assert f() == 3


def test_lazy_methods():
class A(object):
@dispatch(int)
def get(self, _):
return 'int'

@dispatch('A')
def get(self, _):
"""Self reference"""
return 'A'

@dispatch('B')
def get(self, _):
"""Yet undeclared type"""
return 'B'

class B(object):
pass

class C(A):
@dispatch('D')
def get(self, _):
"""Non-existent type"""
return 'D'

a = A()
b = B()
c = C()

assert a.get(1) == 'int'
assert a.get(a) == 'A'
assert a.get(b) == 'B'
assert raises(NameError, lambda: c.get(1))


def test_lazy_functions():
f = Dispatcher('f')
f.add((int,), inc)
f.add(('Int',), dec)

assert raises(NameError, lambda: f(1))

class Int(int):
pass

assert f(1) == 2
assert f(Int(1)) == 0


def test_lazy_serializable():
f = Dispatcher('f')
f.add((int,), inc)
f.add(('Int',), dec)

import pickle
assert isinstance(pickle.dumps(f), (str, bytes))

g = pickle.loads(pickle.dumps(f))

assert f.funcs == g.funcs
assert f._lazy == g._lazy

assert raises(NameError, lambda: f(1))

class Int(int):
pass

assert g(1) == 2
assert g(Int(1)) == 0
21 changes: 21 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher
from multipledispatch.utils import raises


def test_function_annotation_register():
Expand Down Expand Up @@ -92,3 +93,23 @@ def inc(x: int):

assert inc(1) == 2
assert inc(1.0) == 0.0


def test_lazy_annotations():
f = Dispatcher('f')

@f.register()
def inc(x: int):
return x + 1

@f.register()
def dec(x: 'Int'):
return x - 1

assert raises(NameError, lambda: f(1))

class Int(int):
pass

assert f(1) == 2
assert f(Int(1)) == 0