Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Dec 11, 2024
1 parent 076ba13 commit b06bfa8
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
9 changes: 9 additions & 0 deletions src/synchronicity/async_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ async def wrapper(*args, **kwargs):
return functools.wraps(func)


def is_coroutine_function_follow_wrapped(func: typing.Callable) -> bool:
"""Determine if a function returns a coroutine, unwrapping decorators, but not the async synchronicitiy interace."""
from .synchronizer import TARGET_INTERFACE_ATTR # Avoid circular import

if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING:
return is_coroutine_function_follow_wrapped(func.__wrapped__)
return inspect.iscoroutinefunction(func)


YIELD_TYPE = typing.TypeVar("YIELD_TYPE")
SEND_TYPE = typing.TypeVar("SEND_TYPE")

Expand Down
12 changes: 6 additions & 6 deletions src/synchronicity/synchronizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@
)


def iscoroutinefunction(func):
def iscoroutinefunction_follow_wrapped(func):
if hasattr(func, "__wrapped__"):
return iscoroutinefunction(func.__wrapped__)
return iscoroutinefunction_follow_wrapped(func.__wrapped__)
return inspect.iscoroutinefunction(func)


def isasyncgenfunction(func):
def isasyncgenfunction_follow_wrapped(func):
if hasattr(func, "__wrapped__"):
return isasyncgenfunction(func.__wrapped__)
return isasyncgenfunction_follow_wrapped(func.__wrapped__)
return inspect.isasyncgenfunction(func)


Expand All @@ -91,7 +91,7 @@ def _type_requires_aio_usage(annotation, declaration_module):

def should_have_aio_interface(func):
# determines if a blocking function gets an .aio attribute with an async interface to the function or not
if iscoroutinefunction(func) or isasyncgenfunction(func):
if iscoroutinefunction_follow_wrapped(func) or isasyncgenfunction_follow_wrapped(func):
return True
# check annotations if they contain any async entities that would need an event loop to be translated:
# This catches things like vanilla functions returning Coroutines
Expand Down Expand Up @@ -480,7 +480,7 @@ def _wrap_callable(
else:
_name = name

is_coroutinefunction = iscoroutinefunction(f)
is_coroutinefunction = iscoroutinefunction_follow_wrapped(f)

@wraps_by_interface(interface, f)
def f_wrapped(*args, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions src/synchronicity/type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import synchronicity
from synchronicity import combined_types, overload_tracking
from synchronicity.annotations import evaluated_annotation
from synchronicity.async_wrap import is_coroutine_function_follow_wrapped
from synchronicity.interface import Interface
from synchronicity.synchronizer import (
SYNCHRONIZER_ATTR,
Expand Down Expand Up @@ -845,12 +846,12 @@ def _get_function_source(
maybe_decorators = f"{signature_indent}@typing_extensions.dataclass_transform({args})\n"

def is_async(func):
if hasattr(func, "__wrapped__") and not hasattr(func, SYNCHRONIZER_ATTR):
if hasattr(func, "__wrapped__") and getattr(func, TARGET_INTERFACE_ATTR, None) != Interface.BLOCKING:
return is_async(func.__wrapped__)
return inspect.iscoroutinefunction(func)

async_prefix = ""
if is_async(func):
if is_coroutine_function_follow_wrapped(func):
# note: async prefix should not be used for annotated abstract/stub *async generators*,
# so we don't check for inspect.isasyncgenfunction since they contain no yield keyword,
# and would otherwise indicate an awaitable that returns an async generator to static type checkers
Expand Down

0 comments on commit b06bfa8

Please sign in to comment.