diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index 2b5ed8b58a..ef882a2ead 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -6,11 +6,13 @@ import contextlib import functools import inspect -import sys from typing import Callable, Coroutine, Set, Union from fastapi import FastAPI +from reflex.utils import console +from reflex.utils.exceptions import InvalidLifespanTaskType + from .mixin import AppMixin @@ -26,6 +28,7 @@ async def _run_lifespan_tasks(self, app: FastAPI): try: async with contextlib.AsyncExitStack() as stack: for task in self.lifespan_tasks: + run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # type: ignore if isinstance(task, asyncio.Task): running_tasks.append(task) else: @@ -35,15 +38,19 @@ async def _run_lifespan_tasks(self, app: FastAPI): _t = task() if isinstance(_t, contextlib._AsyncGeneratorContextManager): await stack.enter_async_context(_t) + console.debug(run_msg.format(type="asynccontextmanager")) elif isinstance(_t, Coroutine): - running_tasks.append(asyncio.create_task(_t)) + task_ = asyncio.create_task(_t) + task_.add_done_callback(lambda t: t.result()) + running_tasks.append(task_) + console.debug(run_msg.format(type="coroutine")) + else: + console.debug(run_msg.format(type="function")) yield finally: - cancel_kwargs = ( - {"msg": "lifespan_cleanup"} if sys.version_info >= (3, 9) else {} - ) for task in running_tasks: - task.cancel(**cancel_kwargs) + console.debug(f"Canceling lifespan task: {task}") + task.cancel(msg="lifespan_cleanup") def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): """Register a task to run during the lifespan of the app. @@ -51,7 +58,18 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): Args: task: The task to register. task_kwargs: The kwargs of the task. + + Raises: + InvalidLifespanTaskType: If the task is a generator function. """ + if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task): + raise InvalidLifespanTaskType( + f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager." + ) + if task_kwargs: + original_task = task task = functools.partial(task, **task_kwargs) # type: ignore + functools.update_wrapper(task, original_task) # type: ignore self.lifespan_tasks.add(task) # type: ignore + console.debug(f"Registered lifespan task: {task.__name__}") # type: ignore diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 95d68c3b8b..7c35328612 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -111,3 +111,7 @@ class GeneratedCodeHasNoFunctionDefs(ReflexError): class PrimitiveUnserializableToJSON(ReflexError, ValueError): """Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity.""" + + +class InvalidLifespanTaskType(ReflexError, TypeError): + """Raised when an invalid task type is registered as a lifespan task."""