diff --git a/docs/source/resolution.rst b/docs/source/resolution.rst index 65f516c..ab971d7 100644 --- a/docs/source/resolution.rst +++ b/docs/source/resolution.rst @@ -161,6 +161,19 @@ For example, here's a function that takes a ``float`` followed by any number >>> f(2.0, '4', 6, 8) 20.0 +Lazy Dispatch +------------- + +You may need to refer to your own class while defining it. Just use its name as +a string and ``multipledispatch`` will resolve a name to a class during runtime + +.. code:: + + class MyInteger(int): + @dispatch('MyInteger') + def add(self, x): + return self + x + Ambiguities ----------- diff --git a/multipledispatch/conflict.py b/multipledispatch/conflict.py index 5b5e942..10dccf8 100644 --- a/multipledispatch/conflict.py +++ b/multipledispatch/conflict.py @@ -1,3 +1,5 @@ +import itertools + from .utils import _toposort, groupby from .variadic import isvariadic @@ -8,6 +10,9 @@ class AmbiguityWarning(Warning): def supercedes(a, b): """ A is consistent and strictly more specific than B """ + if any(isinstance(x, str) for x in itertools.chain(a, b)): + # skip due to lazy types + return False if len(a) < len(b): # only case is if a is empty and b is variadic return not a and len(b) == 1 and isvariadic(b[-1]) diff --git a/multipledispatch/dispatcher.py b/multipledispatch/dispatcher.py index 7568595..700f127 100644 --- a/multipledispatch/dispatcher.py +++ b/multipledispatch/dispatcher.py @@ -117,7 +117,8 @@ class Dispatcher(object): >>> f(3.0) 2.0 """ - __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' + __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc', \ + '_lazy' def __init__(self, name, doc=None): self.name = self.__name__ = name @@ -125,6 +126,7 @@ def __init__(self, name, doc=None): self.doc = doc self._cache = {} + self._lazy = False def register(self, *types, **kwargs): """ register dispatcher with new implementation @@ -214,9 +216,8 @@ def add(self, signature, func): return new_signature = [] - for index, typ in enumerate(signature, start=1): - if not isinstance(typ, (type, list)): + if not isinstance(typ, (type, list, str)): str_sig = ', '.join(c.__name__ if isinstance(c, type) else str(c) for c in signature) raise TypeError("Tried to dispatch on non-type: %s\n" @@ -237,9 +238,12 @@ def add(self, signature, func): 'To use a variadic union type place the desired types ' 'inside of a tuple, e.g., [(int, str)]' ) - new_signature.append(Variadic[typ[0]]) - else: - new_signature.append(typ) + typ = Variadic[typ[0]] + + if isinstance(typ, str): + self._lazy = True + + new_signature.append(typ) self.funcs[tuple(new_signature)] = func self._cache.clear() @@ -264,6 +268,9 @@ def reorder(self, on_ambiguity=ambiguity_warn): return od def __call__(self, *args, **kwargs): + if self._lazy: + self._unlazy() + types = tuple([type(arg) for arg in args]) try: func = self._cache[types] @@ -359,6 +366,7 @@ def __setstate__(self, d): self.funcs = d['funcs'] self._ordering = ordering(self.funcs) self._cache = dict() + self._lazy = any(isinstance(t, str) for t in itl.chain(*d['funcs'])) @property def __doc__(self): @@ -400,6 +408,31 @@ def source(self, *args, **kwargs): """ Print source code for the function corresponding to inputs """ print(self._source(*args)) + def _unlazy(self): + funcs = {} + for signature, func in self.funcs.items(): + new_signature = [] + for typ in signature: + if isinstance(typ, str): + for frame_info in inspect.stack(): + frame = frame_info[0] + scope = dict(frame.f_globals) + scope.update(frame.f_locals) + if typ in scope: + typ = scope[typ] + break + else: + raise NameError("name '%s' is not defined" % typ) + new_signature.append(typ) + + new_signature = tuple(new_signature) + funcs[new_signature] = func + + self.funcs = funcs + self.reorder() + + self._lazy = False + def source(func): s = 'File: %s\n\n' % inspect.getsourcefile(func) @@ -427,6 +460,9 @@ def __get__(self, instance, owner): return self def __call__(self, *args, **kwargs): + if self._lazy: + self._unlazy() + types = tuple([type(arg) for arg in args]) func = self.dispatch(*types) if not func: diff --git a/multipledispatch/tests/test_dispatcher.py b/multipledispatch/tests/test_dispatcher.py index f07544d..5d23515 100644 --- a/multipledispatch/tests/test_dispatcher.py +++ b/multipledispatch/tests/test_dispatcher.py @@ -1,6 +1,7 @@ import warnings +from multipledispatch import dispatch from multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError, MethodDispatcher) from multipledispatch.conflict import ambiguities @@ -421,3 +422,74 @@ def _3(*objects): assert f('a', ['a']) == 2 assert f(1) == 3 assert f() == 3 + + +def test_lazy_methods(): + class A(object): + @dispatch(int) + def get(self, _): + return 'int' + + @dispatch('A') + def get(self, _): + """Self reference""" + return 'A' + + @dispatch('B') + def get(self, _): + """Yet undeclared type""" + return 'B' + + class B(object): + pass + + class C(A): + @dispatch('D') + def get(self, _): + """Non-existent type""" + return 'D' + + a = A() + b = B() + c = C() + + assert a.get(1) == 'int' + assert a.get(a) == 'A' + assert a.get(b) == 'B' + assert raises(NameError, lambda: c.get(1)) + + +def test_lazy_functions(): + f = Dispatcher('f') + f.add((int,), inc) + f.add(('Int',), dec) + + assert raises(NameError, lambda: f(1)) + + class Int(int): + pass + + assert f(1) == 2 + assert f(Int(1)) == 0 + + +def test_lazy_serializable(): + f = Dispatcher('f') + f.add((int,), inc) + f.add(('Int',), dec) + + import pickle + assert isinstance(pickle.dumps(f), (str, bytes)) + + g = pickle.loads(pickle.dumps(f)) + + assert f.funcs == g.funcs + assert f._lazy == g._lazy + + assert raises(NameError, lambda: f(1)) + + class Int(int): + pass + + assert g(1) == 2 + assert g(Int(1)) == 0 diff --git a/multipledispatch/tests/test_dispatcher_3only.py b/multipledispatch/tests/test_dispatcher_3only.py index b041450..bc4f224 100644 --- a/multipledispatch/tests/test_dispatcher_3only.py +++ b/multipledispatch/tests/test_dispatcher_3only.py @@ -4,6 +4,7 @@ from multipledispatch import dispatch from multipledispatch.dispatcher import Dispatcher +from multipledispatch.utils import raises def test_function_annotation_register(): @@ -92,3 +93,23 @@ def inc(x: int): assert inc(1) == 2 assert inc(1.0) == 0.0 + + +def test_lazy_annotations(): + f = Dispatcher('f') + + @f.register() + def inc(x: int): + return x + 1 + + @f.register() + def dec(x: 'Int'): + return x - 1 + + assert raises(NameError, lambda: f(1)) + + class Int(int): + pass + + assert f(1) == 2 + assert f(Int(1)) == 0