Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for bound methods at runtime #5

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
*.pyc
*.egg-info
*.idea
67 changes: 27 additions & 40 deletions multipledispatch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ class Dispatcher(object):
>>> f(3.0)
2.0
"""
__slots__ = 'name', 'funcs', 'ordering', '_cache'
__slots__ = 'name', 'funcs', 'ordering', '_cache', 'instance'

def __init__(self, name):
self.name = name
self.funcs = dict()
self._cache = dict()
self.instance = None

def add(self, signature, func):
""" Add new types/method pair to dispatcher
Expand All @@ -56,7 +57,15 @@ def add(self, signature, func):
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.resolve(types)
return func(*args, **kwargs)
# check if the Dispatcher has an instance attribute.
# If so, it is being called as a bound method
if self.instance is not None:
result = func(self.instance, *args, **kwargs)
# set instance to None to reset the Dispatcher
self.instance = None
return result
else:
return func(*args, **kwargs)

def __str__(self):
return "<dispatched %s>" % self.name
Expand Down Expand Up @@ -100,26 +109,14 @@ def resolve(self, types):
return result
raise NotImplementedError()


class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature

See Also:
Dispatcher
"""
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
def __get__(self, instance, typ):
"""
This is only called if the Dispatcher is decorating a method. In that
case, the instance attribute is set.
"""
self.instance = instance
return self

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.resolve(types)
return func(self.obj, *args, **kwargs)


global_namespace = dict()


def dispatch(*types, **kwargs):
""" Dispatch function on the types of the inputs
Expand Down Expand Up @@ -164,35 +161,25 @@ def dispatch(*types, **kwargs):
... def __init__(self, datum):
... self.data = [datum]
"""
namespace = kwargs.get('namespace', global_namespace)
namespace = kwargs.get('namespace', None)
types = tuple(types)

def _(func):
name = func.__name__

if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get(name,
MethodDispatcher(name))
else:
frame = inspect.currentframe()
if namespace is not None:
if name not in namespace:
namespace[name] = Dispatcher(name)
namespace[name] = frame.f_locals.get(
name, Dispatcher(name))
dispatcher = namespace[name]

else:
dispatcher = frame.f_back.f_locals.get(name, Dispatcher(name))
del frame
for typs in expand_tuples(types):
dispatcher.add(typs, func)
return dispatcher
return _


def ismethod(func):
""" Is func a method?

Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
spec = inspect.getargspec(func)
return spec and spec.args and spec.args[0] == 'self'


def expand_tuples(L):
"""

Expand Down Expand Up @@ -230,4 +217,4 @@ def warning_text(name, amb):
text += "\n\nConsider making the following additions:\n\n"
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+ ')\ndef %s(...)'%name for s in amb])
return text
return text
55 changes: 44 additions & 11 deletions multipledispatch/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,36 +171,69 @@ def q(x):
def test_methods():
class Foo(object):
@dispatch(float)
def f(self, x):
def methodf(self, x):
return x - 1

@dispatch(int)
def f(self, x):
def methodf(self, x):
return x + 1

@dispatch(int)
def g(self, x):
def methodg(self, x):
return x + 3


foo = Foo()
assert foo.f(1) == 2
assert foo.f(1.0) == 0.0
assert foo.g(1) == 4
assert foo.methodf(1) == 2
assert foo.methodf(1.0) == 0.0
assert foo.methodg(1) == 4


def test_methods_multiple_dispatch():
class Foo(object):
@dispatch(A, A)
def f(x, y):
def methodf(self, x, y):
return 1

@dispatch(A, C)
def f(x, y):
def methodf(self, x, y):
return 2


foo = Foo()
assert foo.f(A(), A()) == 1
assert foo.f(A(), C()) == 2
assert foo.f(C(), C()) == 2
assert foo.methodf(A(), A()) == 1
assert foo.methodf(A(), C()) == 2
assert foo.methodf(C(), C()) == 2


def test_methods_functions_collision():

class Foo(object):
@orig_dispatch(int)
def f(self, x):
return x + 1

@orig_dispatch(float)
def f(self, x):
return x - 1

@orig_dispatch(int)
def f(x):
return x + 10

foo = Foo()
assert foo.f(1) == 2
assert foo.f(1.0) == 0
assert f(1) == 11

class Foo(object):
@dispatch(int)
def f(self, x):
return x + 1

@dispatch(int)
def f(x):
return x + 10

foo = Foo()
assert raises(TypeError, lambda: foo.f(1))