Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bind dispatched methods, and don't rely on "self" argument to infer methods #26

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion multipledispatch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def dispatch(*types, **kwargs):
def _(func):
name = func.__name__

if ismethod(func):
if ismethod(func) and isinclass():
dispatcher = inspect.currentframe().f_back.f_locals.get(name,
MethodDispatcher(name))
else:
Expand All @@ -78,3 +78,11 @@ def ismethod(func):
"""
spec = inspect.getargspec(func)
return spec and spec.args and spec.args[0] == 'self'


def isinclass(n=1):
""" Is the nth previous frame in a class definition?"""
frame = inspect.currentframe().f_back # escape from current function
for _ in range(n):
frame = getattr(frame, 'f_back')
return '__module__' in frame.f_locals
12 changes: 7 additions & 5 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,16 @@ class MethodDispatcher(Dispatcher):
Dispatcher
"""
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
return self
dispatcher = self
def method(self, *args, **kwargs):
return dispatcher(self, *args, **kwargs)
method.__name__ = self.name
return method.__get__(instance, owner)

def __call__(self, *args, **kwargs):
def __call__(self, obj, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.resolve(types)
return func(self.obj, *args, **kwargs)
return func(obj, *args, **kwargs)


def str_signature(sig):
Expand Down
70 changes: 70 additions & 0 deletions multipledispatch/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,21 @@ def g(self, x):


def test_methods_multiple_dispatch():
class Foo(object):
@dispatch(A)
def f(self, y):
return 1

@dispatch(C)
def f(self, y):
return 2

foo = Foo()
assert foo.f(A()) == 1
assert foo.f(C()) == 2


def test_methods_multiple_dispatch_fail():
class Foo(object):
@dispatch(A, A)
def f(x, y):
Expand All @@ -198,8 +213,63 @@ def f(x, y):
def f(x, y):
return 2

@dispatch(int)
def f(x, y): # 'x' as self
return 1 + y

foo = Foo()
# We require the 'self' argument to be used to infer methods
assert foo.f(A(), A()) == 1
assert foo.f(A(), C()) == 2
assert foo.f(C(), C()) == 2
assert raises(TypeError, lambda: foo.f(2))


def test_function_with_self():
@dispatch(A, A)
def f(self, x):
return 1

@dispatch(A, C)
def f(self, x):
return 2

@dispatch(C, A)
def f(self, x):
return 3

@dispatch(C, C)
def f(self, x):
return 4

assert f(A(), A()) == 1
assert f(A(), C()) == 2
assert f(C(), A()) == 3
assert f(C(), C()) == 4


def test_method_dispatch_is_safe():
class Foo(object):
def __init__(self, x):
self.x = x

@dispatch(int)
def f(self, y):
return self.x + y

@dispatch(float)
def f(self, y):
return self.x - y

foo1 = Foo(1)
foo2 = Foo(2)
assert foo1.f(1) == 2
assert foo1.f(1.0) == 0.0
assert foo2.f(1) == 3
assert foo2.f(1.0) == 1.0
f1 = foo1.f
f2 = foo2.f
assert f1(1) == 2
assert f1(1.0) == 0.0
assert f2(1) == 3
assert f2(1.0) == 1.0