diff --git a/src/textual/app.py b/src/textual/app.py index 6861012863..cbb29046ed 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -120,6 +120,7 @@ ScreenResultCallbackType, ScreenResultType, SystemModalScreen, + AwaitScreen, ) from textual.signal import Signal from textual.theme import BUILTIN_THEMES, Theme, ThemeProvider @@ -2650,14 +2651,14 @@ def push_screen( screen: Screen[ScreenResultType] | str, callback: ScreenResultCallbackType[ScreenResultType] | None = None, wait_for_dismiss: Literal[True] = True, - ) -> asyncio.Future[ScreenResultType]: ... + ) -> AwaitScreen[ScreenResultType]: ... def push_screen( self, screen: Screen[ScreenResultType] | str, callback: ScreenResultCallbackType[ScreenResultType] | None = None, wait_for_dismiss: bool = False, - ) -> AwaitMount | asyncio.Future[ScreenResultType]: + ) -> AwaitMount | AwaitScreen[ScreenResultType]: """Push a new [screen](/guide/screens) on the screen stack, making it the current screen. Args: @@ -2670,7 +2671,7 @@ def push_screen( NoActiveWorker: If using `wait_for_dismiss` outside of a worker. Returns: - An optional awaitable that awaits the mounting of the screen and its children, or an asyncio Future + An optional awaitable that awaits the mounting of the screen and its children, or an awaitable to await the result of the screen. """ if not isinstance(screen, (Screen, str)): @@ -2678,14 +2679,6 @@ def push_screen( f"push_screen requires a Screen instance or str; not {screen!r}" ) - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # Mainly for testing, when push_screen isn't called in an async context - future: asyncio.Future[ScreenResultType] = asyncio.Future() - else: - future = loop.create_future() - if self._screen_stack: self.screen.post_message(events.ScreenSuspend()) self.screen.refresh() @@ -2695,6 +2688,7 @@ def push_screen( except LookupError: message_pump = self.app + future: AwaitScreen[ScreenResultType] = AwaitScreen() next_screen._push_result_callback(message_pump, callback, future) self._load_screen_css(next_screen) self._screen_stack.append(next_screen) diff --git a/src/textual/screen.py b/src/textual/screen.py index 1268858e22..ff57173bb9 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -8,6 +8,7 @@ from __future__ import annotations +import enum import asyncio from functools import partial from operator import attrgetter @@ -23,6 +24,8 @@ Optional, TypeVar, Union, + Literal, + Generator, ) import rich.repr @@ -82,6 +85,27 @@ ] """Type of a screen result callback function.""" +class _Unset(enum.Enum): + UNSET = enum.auto() + +class AwaitScreen(Generic[ScreenResultType]): + def __init__(self) -> None: + self._event = asyncio.Event() + self._result: ScreenResultType | Literal[_Unset.UNSET] = _Unset.UNSET + + async def wait(self) -> ScreenResultType: + await self._event.wait() + assert self._result is not _Unset.UNSET + return self._result + + def __await__(self) -> Generator[Any, Any, ScreenResultType]: + return self.wait().__await__() + + def set_result(self, result): + assert self._result is _Unset.UNSET + self._result = result + self._event.set() + @rich.repr.auto class ResultCallback(Generic[ScreenResultType]): @@ -91,7 +115,7 @@ def __init__( self, requester: MessagePump, callback: ScreenResultCallbackType[ScreenResultType] | None, - future: asyncio.Future[ScreenResultType] | None = None, + future: AwaitScreen[ScreenResultType] | None = None, ) -> None: """Initialise the result callback object. @@ -1161,7 +1185,7 @@ def _push_result_callback( self, requester: MessagePump, callback: ScreenResultCallbackType[ScreenResultType] | None, - future: asyncio.Future[ScreenResultType | None] | None = None, + future: AwaitScreen[ScreenResultType] | None = None, ) -> None: """Add a result callback to the screen.