Skip to content

Commit

Permalink
now works on descriptors too!
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Sep 8, 2022
1 parent d716c0b commit e9d1ade
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "trouting"
version = "0.2.2"
version = "0.3.0"
description = "Trouting (short for Type Routing) is a simple class decorator that allows to define multiple interfaces for a method that behave differently depending on input types."
authors = [
{name = "Luca Soldaini", email = "[email protected]" }
Expand Down
91 changes: 65 additions & 26 deletions src/trouting/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
Callable,
Dict,
Generic,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from typing_extensions import Concatenate, ParamSpec
Expand Down Expand Up @@ -47,7 +46,7 @@ def add_one_str(self, a: str) -> str:
interfaces: Dict[Tuple[type, ...], Callable[Concatenate[Any, P], R]]

def __init__(
self, interfaced_method: Callable[Concatenate[Any, P], R]
self, fallback_method: Callable[Concatenate[Any, P], R]
) -> None:
"""Create an Interface object.
Expand All @@ -57,9 +56,8 @@ def __init__(
"""
self.interfaces = {}
self.bounded_args = None
self._interfaced_method = interfaced_method
self._method_signature = inspect.signature(interfaced_method)
self._obj = None
self.fallback_method = fallback_method
self.is_descriptor = inspect.ismethoddescriptor(fallback_method)

def _expand_interface_combinations(
self, nested_interface_spec: Dict[str, Union[type, Tuple[type, ...]]]
Expand Down Expand Up @@ -98,7 +96,7 @@ def add_interface(
if self.bounded_args is None:
self.bounded_args = current_interface_args
elif self.bounded_args != current_interface_args:
raise ValueError(
raise TypeError(
"All interfaces must have the same arguments; the current "
f"interface has arguments {current_interface_args}, but the "
f"previous interface has arguments {self.bounded_args}"
Expand All @@ -107,45 +105,86 @@ def add_interface(
def _add_interface(
method: Callable[Concatenate[Any, P], R]
) -> "trouting":
if self.is_descriptor:
if not inspect.ismethoddescriptor(method):
raise TypeError(
"All interfaces must be descriptors; the current "
"interface is a function."
)
elif not isinstance(self.fallback_method, type(method)):
raise TypeError(
"All interfaces must be of the same type; the current "
f"interface is a {type(method)}, but the previous "
f"interface is {type(self.fallback_method)}."
)

for interface_spec in interface_specs:
# register the same method for all types in the interface spec
self.interfaces[tuple(interface_spec.values())] = method
# have to add an ignore because pyright is being a bit too
# clever here.
self.interfaces[ # pyright: ignore
tuple(interface_spec.values())
] = method

return self

return _add_interface

def __get__(
self, obj: Any, type: Optional[Type] = None
) -> Callable[Concatenate[P], R]:
def __get__(self, obj: Any, type: Any) -> Callable[Concatenate[P], R]:
"""Return a bound method that calls the correct interface."""
return partial(self.__call__, __obj__=obj)
return partial(
self.__call__, __trouting_obj__=obj, __trouting_type__=type
)

def _bound_method(self, method: Any, obj: Any, cls: Any) -> Callable:
if self.is_descriptor:
bound_method = method.__get__(obj, cast(type, cls))
else:
# populate the first argument with the object or class here
bound_method = partial(method, obj or cls)
return bound_method

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the interfaced method with the correct interface."""

if (obj := kwargs.pop("__obj__", MISSING)) is MISSING:
if (
__trouting_obj__ := kwargs.pop("__trouting_obj__", MISSING)
) is MISSING:
raise ValueError(
"__obj__ is required; `Interface._run_interface` "
"__trouting_obj__ is required; `Interface._run_interface` "
"was improperly called; You might have called a trouted "
"method in an invalid way; If you think you are using this "
"library correctly, please file a bug report."
)
if (
__trouting_type__ := kwargs.pop("__trouting_type__", MISSING)
) is MISSING:
raise ValueError(
"__trouting_type__ is required; `Interface._run_interface` "
"was improperly called; You might have called a trouted "
"method in an invalid way; If you think you are using this "
"library correctly, please file a bug report."
)

if self.bounded_args is None:
# no interfaces have been added, so we fall back to the default
return self._interfaced_method(obj, *args, **kwargs)
bounded_fallback_method = self._bound_method(
self.fallback_method, __trouting_obj__, __trouting_type__
)

sig_vals = self._method_signature.bind(self, *args, **kwargs)
method_to_call = None
sig_vals = inspect.signature(bounded_fallback_method).bind(
*args, **kwargs
)

current_types = (
current_types = tuple(
type(sig_vals.arguments[arg_name])
for arg_name in self.bounded_args
for arg_name in (self.bounded_args or tuple())
)

# fall back to the default method if we didn't find anything
method_to_call = self.interfaces.get(
tuple(current_types), self._interfaced_method
)
method_to_call = self.interfaces.get(current_types, None)
if method_to_call is None:
method_to_call = bounded_fallback_method
else:
method_to_call = self._bound_method(
method_to_call, __trouting_obj__, __trouting_type__
)

return method_to_call(obj, *args, **kwargs)
return method_to_call(*args, **kwargs)
84 changes: 84 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Any
from unittest import TestCase

from trouting import trouting


class TroutedClass:
@trouting
@classmethod
def add_one(cls, a: Any) -> Any:
raise TypeError(f"Type {type(a)} not supported for +1")

@add_one.add_interface(a=int)
@classmethod
def add_one_int(cls, a: int) -> int:
return a + 1

@add_one.add_interface(a=str)
@classmethod
def add_one_str(cls, a: str) -> str:
return a + "1"

@trouting
def add_two(self, a: Any) -> Any:
raise TypeError(f"Type {type(a)} not supported for +2")

@add_two.add_interface(a=int)
def add_two_int(self, a: int) -> int:
return a + 2

@add_two.add_interface(a=str)
def add_two_str(self, a: str) -> str:
return a + "2"

@trouting
@staticmethod
def add_three(a: Any) -> Any:
raise TypeError(f"Type {type(a)} not supported for +3")

@add_three.add_interface(a=int)
@staticmethod
def add_three_int(a: int) -> int:
return a + 3

@add_three.add_interface(a=str)
@staticmethod
def add_three_str(a: str) -> str:
return a + "3"


class TestDecorators(TestCase):
def test_classmethod(self):
self.assertEqual(TroutedClass.add_one(1), 2)
self.assertEqual(TroutedClass.add_one("1"), "11")

def test_instance_method(self):
self.assertEqual(TroutedClass().add_two(1), 3)
self.assertEqual(TroutedClass().add_two("1"), "12")

def test_staticmethod(self):
# TODO[soldni]: need to fix typing annotation for trouting
# so that pylance doesn't freak out when using staticmethod
self.assertEqual(TroutedClass.add_three(1), 4) # pyright: ignore
self.assertEqual(TroutedClass.add_three("1"), "13") # pyright: ignore

def test_raise_error_uneven_interfaces(self):
class _:
@trouting
@classmethod
def add_one(cls, a: Any) -> Any:
raise TypeError(f"Type {type(a)} not supported for +1")

with self.assertRaises(TypeError):

@add_one.add_interface(a=int)
def add_one_int(cls, a: int) -> int:
return a + 1

with self.assertRaises(TypeError):
# Type ignore because this is intentionally wrong
@add_one.add_interface(a=str) # type: ignore
@staticmethod
def add_one_str(a: str) -> str:
return a + "1"

0 comments on commit e9d1ade

Please sign in to comment.