Skip to content

Commit

Permalink
Emit correct type stubs for async functions wrapped with additional d…
Browse files Browse the repository at this point in the history
…ecorators
  • Loading branch information
mwaskom committed Dec 10, 2024
1 parent 181bb89 commit f9b01a7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
16 changes: 14 additions & 2 deletions src/synchronicity/synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@
)


def iscoroutinefunction(func):
if hasattr(func, "__wrapped__"):
return iscoroutinefunction(func.__wrapped__)
return inspect.iscoroutinefunction(func)


def isasyncgenfunction(func):
if hasattr(func, "__wrapped__"):
return isasyncgenfunction(func.__wrapped__)
return inspect.isasyncgenfunction(func)


def _type_requires_aio_usage(annotation, declaration_module):
if isinstance(annotation, ForwardRef):
annotation = annotation.__forward_arg__
Expand All @@ -79,7 +91,7 @@ def _type_requires_aio_usage(annotation, declaration_module):

def should_have_aio_interface(func):
# determines if a blocking function gets an .aio attribute with an async interface to the function or not
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
if iscoroutinefunction(func) or isasyncgenfunction(func):
return True
# check annotations if they contain any async entities that would need an event loop to be translated:
# This catches things like vanilla functions returning Coroutines
Expand Down Expand Up @@ -468,7 +480,7 @@ def _wrap_callable(
else:
_name = name

is_coroutinefunction = inspect.iscoroutinefunction(f)
is_coroutinefunction = iscoroutinefunction(f)

@wraps_by_interface(interface, f)
def f_wrapped(*args, **kwargs):
Expand Down
8 changes: 7 additions & 1 deletion src/synchronicity/type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def final_transform_signature(sig):
{aio_func_source}
{body_indent}{entity_name}: __{entity_name}_spec{parent_type_var_names_spec}
"""

return protocol_attr

def _prepare_method_generic_type_vars(self, entity, parent_generic_type_vars):
Expand Down Expand Up @@ -843,8 +844,13 @@ def _get_function_source(
self.imports.add("typing_extensions")
maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n"

def is_async(func):
if hasattr(func, "__wrapped__") and not hasattr(func, SYNCHRONIZER_ATTR):
return is_async(func.__wrapped__)
return inspect.iscoroutinefunction(func)

async_prefix = ""
if inspect.iscoroutinefunction(func):
if is_async(func):
# note: async prefix should not be used for annotated abstract/stub *async generators*,
# so we don't check for inspect.isasyncgenfunction since they contain no yield keyword,
# and would otherwise indicate an awaitable that returns an async generator to static type checkers
Expand Down
11 changes: 11 additions & 0 deletions test/type_stub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,17 @@ def wrapper(extra_arg: int, *args, **kwargs):
assert _function_source(wrapper) == "def orig(extra_arg: int, arg: float):\n ...\n"


def test_wrapped_function_preserves_color():
async def orig(arg: str):
...

@functools.wraps(orig)
def wrapper(*args, **kwargs):
return orig(*args, **kwargs)

assert _function_source(wrapper) == "async def orig(arg: str):\n ...\n"


class Base:
def base_method(self) -> str:
return ""
Expand Down

0 comments on commit f9b01a7

Please sign in to comment.