From 01b8e31bbb1d80028cae3a5209013680ef1cbae0 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 10 Dec 2024 20:51:22 +0000 Subject: [PATCH 1/4] Emit correct type stubs for async functions wrapped with additional decorators --- src/synchronicity/synchronizer.py | 16 ++++++++++++++-- src/synchronicity/type_stubs.py | 8 +++++++- test/type_stub_test.py | 10 ++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/synchronicity/synchronizer.py b/src/synchronicity/synchronizer.py index e911e72..eba693c 100644 --- a/src/synchronicity/synchronizer.py +++ b/src/synchronicity/synchronizer.py @@ -56,6 +56,18 @@ ) +def iscoroutinefunction(func): + if hasattr(func, "__wrapped__"): + return iscoroutinefunction(func.__wrapped__) + return inspect.iscoroutinefunction(func) + + +def isasyncgenfunction(func): + if hasattr(func, "__wrapped__"): + return isasyncgenfunction(func.__wrapped__) + return inspect.isasyncgenfunction(func) + + def _type_requires_aio_usage(annotation, declaration_module): if isinstance(annotation, ForwardRef): annotation = annotation.__forward_arg__ @@ -79,7 +91,7 @@ def _type_requires_aio_usage(annotation, declaration_module): def should_have_aio_interface(func): # determines if a blocking function gets an .aio attribute with an async interface to the function or not - if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + if iscoroutinefunction(func) or isasyncgenfunction(func): return True # check annotations if they contain any async entities that would need an event loop to be translated: # This catches things like vanilla functions returning Coroutines @@ -468,7 +480,7 @@ def _wrap_callable( else: _name = name - is_coroutinefunction = inspect.iscoroutinefunction(f) + is_coroutinefunction = iscoroutinefunction(f) @wraps_by_interface(interface, f) def f_wrapped(*args, **kwargs): diff --git a/src/synchronicity/type_stubs.py b/src/synchronicity/type_stubs.py index caa3e28..6a43dfa 100644 --- a/src/synchronicity/type_stubs.py +++ b/src/synchronicity/type_stubs.py @@ -390,6 +390,7 @@ def final_transform_signature(sig): {aio_func_source} {body_indent}{entity_name}: __{entity_name}_spec{parent_type_var_names_spec} """ + return protocol_attr def _prepare_method_generic_type_vars(self, entity, parent_generic_type_vars): @@ -855,8 +856,13 @@ def _get_function_source( self.imports.add("typing_extensions") maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n" + def is_async(func): + if hasattr(func, "__wrapped__") and not hasattr(func, SYNCHRONIZER_ATTR): + return is_async(func.__wrapped__) + return inspect.iscoroutinefunction(func) + async_prefix = "" - if inspect.iscoroutinefunction(func): + if is_async(func): # note: async prefix should not be used for annotated abstract/stub *async generators*, # so we don't check for inspect.isasyncgenfunction since they contain no yield keyword, # and would otherwise indicate an awaitable that returns an async generator to static type checkers diff --git a/test/type_stub_test.py b/test/type_stub_test.py index 973cc59..02a5663 100644 --- a/test/type_stub_test.py +++ b/test/type_stub_test.py @@ -178,6 +178,16 @@ def wrapper(extra_arg: int, *args, **kwargs): assert _function_source(wrapper) == "def orig(extra_arg: int, arg: float):\n ...\n" +def test_wrapped_async_func_remains_async(): + async def orig(arg: str): ... + + @functools.wraps(orig) + def wrapper(*args, **kwargs): + return orig(*args, **kwargs) + + assert _function_source(wrapper) == "async def orig(arg: str):\n ...\n" + + class Base: def base_method(self) -> str: return "" From eee04c0e06bb5c48c7b3934ca7c6138a53b07071 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 11 Dec 2024 17:09:17 +0000 Subject: [PATCH 2/4] Address comments --- src/synchronicity/async_wrap.py | 9 +++++++++ src/synchronicity/synchronizer.py | 12 ++++++------ src/synchronicity/type_stubs.py | 5 +++-- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/synchronicity/async_wrap.py b/src/synchronicity/async_wrap.py index e81cec3..64ba0d8 100644 --- a/src/synchronicity/async_wrap.py +++ b/src/synchronicity/async_wrap.py @@ -38,6 +38,15 @@ async def wrapper(*args, **kwargs): return functools.wraps(func) +def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: + """Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicitiy interace.""" + from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import + + if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: + return is_coroutine_function_follow_wrapped(func.__wrapped__) + return inspect.iscoroutinefunction(func) + + YIELD_TYPE = typing.TypeVar("YIELD_TYPE") SEND_TYPE = typing.TypeVar("SEND_TYPE") diff --git a/src/synchronicity/synchronizer.py b/src/synchronicity/synchronizer.py index eba693c..79080ea 100644 --- a/src/synchronicity/synchronizer.py +++ b/src/synchronicity/synchronizer.py @@ -56,15 +56,15 @@ ) -def iscoroutinefunction(func): +def iscoroutinefunction_follow_wrapped(func): if hasattr(func, "__wrapped__"): - return iscoroutinefunction(func.__wrapped__) + return iscoroutinefunction_follow_wrapped(func.__wrapped__) return inspect.iscoroutinefunction(func) -def isasyncgenfunction(func): +def isasyncgenfunction_follow_wrapped(func): if hasattr(func, "__wrapped__"): - return isasyncgenfunction(func.__wrapped__) + return isasyncgenfunction_follow_wrapped(func.__wrapped__) return inspect.isasyncgenfunction(func) @@ -91,7 +91,7 @@ def _type_requires_aio_usage(annotation, declaration_module): def should_have_aio_interface(func): # determines if a blocking function gets an .aio attribute with an async interface to the function or not - if iscoroutinefunction(func) or isasyncgenfunction(func): + if iscoroutinefunction_follow_wrapped(func) or isasyncgenfunction_follow_wrapped(func): return True # check annotations if they contain any async entities that would need an event loop to be translated: # This catches things like vanilla functions returning Coroutines @@ -480,7 +480,7 @@ def _wrap_callable( else: _name = name - is_coroutinefunction = iscoroutinefunction(f) + is_coroutinefunction = iscoroutinefunction_follow_wrapped(f) @wraps_by_interface(interface, f) def f_wrapped(*args, **kwargs): diff --git a/src/synchronicity/type_stubs.py b/src/synchronicity/type_stubs.py index 6a43dfa..d9e1e93 100644 --- a/src/synchronicity/type_stubs.py +++ b/src/synchronicity/type_stubs.py @@ -27,6 +27,7 @@ import synchronicity from synchronicity import combined_types, overload_tracking from synchronicity.annotations import TYPE_CHECKING_OVERRIDES, evaluated_annotation +from synchronicity.async_wrap import is_coroutine_function_follow_wrapped from synchronicity.interface import Interface from synchronicity.synchronizer import ( SYNCHRONIZER_ATTR, @@ -857,12 +858,12 @@ def _get_function_source( maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n" def is_async(func): - if hasattr(func, "__wrapped__") and not hasattr(func, SYNCHRONIZER_ATTR): + if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: return is_async(func.__wrapped__) return inspect.iscoroutinefunction(func) async_prefix = "" - if is_async(func): + if is_coroutine_function_follow_wrapped(func): # note: async prefix should not be used for annotated abstract/stub *async generators*, # so we don't check for inspect.isasyncgenfunction since they contain no yield keyword, # and would otherwise indicate an awaitable that returns an async generator to static type checkers From a1f7e004bd976ac532b562d123bf81c71a39ae5d Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 16 Dec 2024 17:09:32 +0000 Subject: [PATCH 3/4] Simplify --- src/synchronicity/async_wrap.py | 12 +++++++++++- src/synchronicity/synchronizer.py | 20 +++----------------- src/synchronicity/type_stubs.py | 5 ----- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/synchronicity/async_wrap.py b/src/synchronicity/async_wrap.py index 64ba0d8..27f92b0 100644 --- a/src/synchronicity/async_wrap.py +++ b/src/synchronicity/async_wrap.py @@ -39,7 +39,7 @@ async def wrapper(*args, **kwargs): def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: - """Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicitiy interace.""" + """Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicity interace.""" from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: @@ -47,6 +47,16 @@ def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: return inspect.iscoroutinefunction(func) +def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool: + """Determine if a function returns an async generator, unwrapping decorators, but not the async synchronicity interace.""" + from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import + + if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: + return is_async_gen_function_follow_wrapped(func.__wrapped__) + return inspect.isasyncgenfunction(func) + + + YIELD_TYPE = typing.TypeVar("YIELD_TYPE") SEND_TYPE = typing.TypeVar("SEND_TYPE") diff --git a/src/synchronicity/synchronizer.py b/src/synchronicity/synchronizer.py index 79080ea..d671f12 100644 --- a/src/synchronicity/synchronizer.py +++ b/src/synchronicity/synchronizer.py @@ -20,7 +20,7 @@ from synchronicity.annotations import evaluated_annotation from synchronicity.combined_types import FunctionWithAio, MethodWithAio -from .async_wrap import wraps_by_interface +from .async_wrap import is_async_gen_function_follow_wrapped, is_coroutine_function_follow_wrapped, wraps_by_interface from .callback import Callback from .exceptions import UserCodeException, unwrap_coro_exception, wrap_coro_exception from .interface import DEFAULT_CLASS_PREFIX, DEFAULT_FUNCTION_PREFIXES, Interface @@ -56,18 +56,6 @@ ) -def iscoroutinefunction_follow_wrapped(func): - if hasattr(func, "__wrapped__"): - return iscoroutinefunction_follow_wrapped(func.__wrapped__) - return inspect.iscoroutinefunction(func) - - -def isasyncgenfunction_follow_wrapped(func): - if hasattr(func, "__wrapped__"): - return isasyncgenfunction_follow_wrapped(func.__wrapped__) - return inspect.isasyncgenfunction(func) - - def _type_requires_aio_usage(annotation, declaration_module): if isinstance(annotation, ForwardRef): annotation = annotation.__forward_arg__ @@ -91,7 +79,7 @@ def _type_requires_aio_usage(annotation, declaration_module): def should_have_aio_interface(func): # determines if a blocking function gets an .aio attribute with an async interface to the function or not - if iscoroutinefunction_follow_wrapped(func) or isasyncgenfunction_follow_wrapped(func): + if is_coroutine_function_follow_wrapped(func) or is_async_gen_function_follow_wrapped(func): return True # check annotations if they contain any async entities that would need an event loop to be translated: # This catches things like vanilla functions returning Coroutines @@ -480,8 +468,6 @@ def _wrap_callable( else: _name = name - is_coroutinefunction = iscoroutinefunction_follow_wrapped(f) - @wraps_by_interface(interface, f) def f_wrapped(*args, **kwargs): return_future = kwargs.pop(_RETURN_FUTURE_KWARG, False) @@ -511,7 +497,7 @@ def f_wrapped(*args, **kwargs): elif is_coroutine: if interface == Interface._ASYNC_WITH_BLOCKING_TYPES: coro = self._run_function_async(res, f) - if not is_coroutinefunction: + if not is_coroutine_function_follow_wrapped(f): # If this is a non-async function that returns a coroutine, # then this is the exit point, and we need to unwrap any # wrapped exception here. Otherwise, the exit point is diff --git a/src/synchronicity/type_stubs.py b/src/synchronicity/type_stubs.py index d9e1e93..1d3c062 100644 --- a/src/synchronicity/type_stubs.py +++ b/src/synchronicity/type_stubs.py @@ -857,11 +857,6 @@ def _get_function_source( self.imports.add("typing_extensions") maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n" - def is_async(func): - if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: - return is_async(func.__wrapped__) - return inspect.iscoroutinefunction(func) - async_prefix = "" if is_coroutine_function_follow_wrapped(func): # note: async prefix should not be used for annotated abstract/stub *async generators*, From 8b4b763ab7930963ee0732b6d81210c068ae0bda Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 16 Dec 2024 17:19:14 +0000 Subject: [PATCH 4/4] Lint --- src/synchronicity/async_wrap.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/synchronicity/async_wrap.py b/src/synchronicity/async_wrap.py index 27f92b0..f60a9c1 100644 --- a/src/synchronicity/async_wrap.py +++ b/src/synchronicity/async_wrap.py @@ -39,7 +39,7 @@ async def wrapper(*args, **kwargs): def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: - """Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicity interace.""" + """Determine if func returns a coroutine, unwrapping decorators, but not the async synchronicity interace.""" from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: @@ -48,7 +48,7 @@ def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool: def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool: - """Determine if a function returns an async generator, unwrapping decorators, but not the async synchronicity interace.""" + """Determine if func returns an async generator, unwrapping decorators, but not the async synchronicity interace.""" from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING: @@ -56,7 +56,6 @@ def is_async_gen_function_follow_wrapped(func: typing.Callable) -> bool: return inspect.isasyncgenfunction(func) - YIELD_TYPE = typing.TypeVar("YIELD_TYPE") SEND_TYPE = typing.TypeVar("SEND_TYPE")