diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 15b39c470443..acb24efa84d0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -33,7 +33,7 @@ from flwr.cli.config_utils import get_fab_metadata from flwr.cli.install import install_from_fab from flwr.client.client import Client -from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.client_app import ClientApp, LoadClientAppError, manage_client_app from flwr.client.clientapp.app import flwr_clientapp from flwr.client.nodestate.nodestate_factory import NodeStateFactory from flwr.client.typing import ClientFnExt @@ -567,7 +567,10 @@ def _on_backoff(retry_state: RetryState) -> None: ) # Execute ClientApp - reply_message = client_app(message=message, context=context) + with manage_client_app(client_app, context=context): + reply_message = client_app( + message=message, context=context + ) except Exception as ex: # pylint: disable=broad-exception-caught # Legacy grpc-bidi diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index 234d84f27782..f51e80f392a3 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -16,6 +16,8 @@ import inspect +from collections.abc import Iterator +from contextlib import contextmanager from typing import Callable, Optional from flwr.client.client import Client @@ -135,6 +137,8 @@ def ffn( self._train: Optional[ClientAppCallable] = None self._evaluate: Optional[ClientAppCallable] = None self._query: Optional[ClientAppCallable] = None + self._enter: Optional[Callable[[Context], None]] = None + self._exit: Optional[Callable[[Context], None]] = None def __call__(self, message: Message, context: Context) -> Message: """Execute `ClientApp`.""" @@ -249,6 +253,58 @@ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable: return query_decorator + def enter(self) -> Callable[[Callable[[Context], None]], Callable[[Context], None]]: + """Return a decorator that registers the enter fn with the client app. + + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.enter() + >>> def enter(context: Context) -> None: + >>> print("ClientApp enter running") + """ + + def enter_decorator( + enter_fn: Callable[[Context], None] + ) -> Callable[[Context], None]: + """Register the enter fn with the ServerApp object.""" + warn_preview_feature("ClientApp-register-enter-function") + + # Register provided function with the ClientApp object + self._enter = enter_fn + + # Return provided function unmodified + return enter_fn + + return enter_decorator + + def exit(self) -> Callable[[Callable[[Context], None]], Callable[[Context], None]]: + """Return a decorator that registers the exit fn with the client app. + + Examples + -------- + >>> app = ClientApp() + >>> + >>> @app.exit() + >>> def exit(context: Context) -> None: + >>> print("ClientApp exit running") + """ + + def exit_decorator( + exit_fn: Callable[[Context], None] + ) -> Callable[[Context], None]: + """Register the exit fn with the ServerApp object.""" + warn_preview_feature("ClientApp-register-exit-function") + + # Register provided function with the ClientApp object + self._exit = exit_fn + + # Return provided function unmodified + return exit_fn + + return exit_decorator + class LoadClientAppError(Exception): """Error when trying to load `ClientApp`.""" @@ -283,3 +339,17 @@ def _registration_error(fn_name: str) -> ValueError: >>> ) """, ) + + +@contextmanager +def manage_client_app(app: ClientApp, context: Context) -> Iterator[None]: + """Manage the lifecycle of a ClientApp.""" + # pylint: disable=protected-access + try: + if app._enter is not None: + app._enter(context) + yield + finally: + if app._exit is not None: + app._exit(context) + # pylint: enable=protected-access diff --git a/src/py/flwr/client/clientapp/app.py b/src/py/flwr/client/clientapp/app.py index cef822a14e86..212ef12e3527 100644 --- a/src/py/flwr/client/clientapp/app.py +++ b/src/py/flwr/client/clientapp/app.py @@ -23,7 +23,7 @@ import grpc from flwr.cli.install import install_from_fab -from flwr.client.client_app import ClientApp, LoadClientAppError +from flwr.client.client_app import ClientApp, LoadClientAppError, manage_client_app from flwr.common import Context, Message from flwr.common.args import add_args_flwr_app_common from flwr.common.config import get_flwr_dir @@ -133,7 +133,8 @@ def run_clientapp( # pylint: disable=R0914 ) # Execute ClientApp - reply_message = client_app(message=message, context=context) + with manage_client_app(client_app, context=context): + reply_message = client_app(message=message, context=context) except Exception as ex: # pylint: disable=broad-exception-caught # Don't update/change NodeState diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 49660c5ff077..563e14e41bc1 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -24,7 +24,7 @@ from flwr.common.object_ref import load_app from .driver import Driver -from .server_app import LoadServerAppError, ServerApp +from .server_app import LoadServerAppError, ServerApp, manage_server_app def run( @@ -60,7 +60,8 @@ def _load() -> ServerApp: server_app = _load() # Call ServerApp - server_app(driver=driver, context=context) + with manage_server_app(server_app, driver=driver, context=context): + server_app(driver=driver, context=context) log(DEBUG, "ServerApp finished running.") return context diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py index 9d91be88e94e..66d34b81f3e3 100644 --- a/src/py/flwr/server/server_app.py +++ b/src/py/flwr/server/server_app.py @@ -15,6 +15,8 @@ """Flower ServerApp.""" +from collections.abc import Iterator +from contextlib import contextmanager from typing import Callable, Optional from flwr.common import Context @@ -45,7 +47,7 @@ def server_fn(context: Context): """ -class ServerApp: +class ServerApp: # pylint: disable=too-many-instance-attributes """Flower ServerApp. Examples @@ -105,6 +107,8 @@ def __init__( self._client_manager = client_manager self._server_fn = server_fn self._main: Optional[ServerAppCallable] = None + self._enter: Optional[ServerAppCallable] = None + self._exit: Optional[ServerAppCallable] = None def __call__(self, driver: Driver, context: Context) -> None: """Execute `ServerApp`.""" @@ -177,6 +181,70 @@ def main_decorator(main_fn: ServerAppCallable) -> ServerAppCallable: return main_decorator + def enter(self) -> Callable[[ServerAppCallable], ServerAppCallable]: + """Return a decorator that registers the enter fn with the server app. + + Examples + -------- + >>> app = ServerApp() + >>> + >>> @app.enter() + >>> def enter(driver: Driver, context: Context) -> None: + >>> print("ServerApp enter running") + """ + + def enter_decorator(enter_fn: ServerAppCallable) -> ServerAppCallable: + """Register the enter fn with the ServerApp object.""" + warn_preview_feature("ServerApp-register-enter-function") + + # Register provided function with the ServerApp object + self._enter = enter_fn + + # Return provided function unmodified + return enter_fn + + return enter_decorator + + def exit(self) -> Callable[[ServerAppCallable], ServerAppCallable]: + """Return a decorator that registers the exit fn with the server app. + + Examples + -------- + >>> app = ServerApp() + >>> + >>> @app.exit() + >>> def exit(context: Context) -> None: + >>> print("ServerApp exit running") + """ + + def exit_decorator(exit_fn: ServerAppCallable) -> ServerAppCallable: + """Register the exit fn with the ServerApp object.""" + warn_preview_feature("ServerApp-register-exit-function") + + # Register provided function with the ServerApp object + self._exit = exit_fn + + # Return provided function unmodified + return exit_fn + + return exit_decorator + class LoadServerAppError(Exception): """Error when trying to load `ServerApp`.""" + + +@contextmanager +def manage_server_app( + app: ServerApp, driver: Driver, context: Context +) -> Iterator[None]: + """Manage the lifecycle of a ServerApp.""" + # pylint: disable=protected-access + try: + if app._enter is not None: + app._enter(driver, context) + yield + finally: + if app._exit is not None: + app._exit(driver, context) + # pylint: enable=protected-access