Skip to content

Commit

Permalink
enable app.enter and app.exit
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Feb 11, 2025
1 parent 4f0201a commit 7aa9d31
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 7 deletions.
7 changes: 5 additions & 2 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.client_app import ClientApp, LoadClientAppError, manage_client_app
from flwr.client.clientapp.app import flwr_clientapp
from flwr.client.nodestate.nodestate_factory import NodeStateFactory
from flwr.client.typing import ClientFnExt
Expand Down Expand Up @@ -567,7 +567,10 @@ def _on_backoff(retry_state: RetryState) -> None:
)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
with manage_client_app(client_app, context=context):
reply_message = client_app(
message=message, context=context
)
except Exception as ex: # pylint: disable=broad-exception-caught

# Legacy grpc-bidi
Expand Down
70 changes: 70 additions & 0 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@


import inspect
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Callable, Optional

from flwr.client.client import Client
Expand Down Expand Up @@ -135,6 +137,8 @@ def ffn(
self._train: Optional[ClientAppCallable] = None
self._evaluate: Optional[ClientAppCallable] = None
self._query: Optional[ClientAppCallable] = None
self._enter: Optional[Callable[[Context], None]] = None
self._exit: Optional[Callable[[Context], None]] = None

def __call__(self, message: Message, context: Context) -> Message:
"""Execute `ClientApp`."""
Expand Down Expand Up @@ -249,6 +253,58 @@ def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:

return query_decorator

def enter(self) -> Callable[[Callable[[Context], None]], Callable[[Context], None]]:
"""Return a decorator that registers the enter fn with the client app.
Examples
--------
>>> app = ClientApp()
>>>
>>> @app.enter()
>>> def enter(context: Context) -> None:
>>> print("ClientApp enter running")
"""

def enter_decorator(
enter_fn: Callable[[Context], None]
) -> Callable[[Context], None]:
"""Register the enter fn with the ServerApp object."""
warn_preview_feature("ClientApp-register-enter-function")

# Register provided function with the ClientApp object
self._enter = enter_fn

# Return provided function unmodified
return enter_fn

return enter_decorator

def exit(self) -> Callable[[Callable[[Context], None]], Callable[[Context], None]]:
"""Return a decorator that registers the exit fn with the client app.
Examples
--------
>>> app = ClientApp()
>>>
>>> @app.exit()
>>> def exit(context: Context) -> None:
>>> print("ClientApp exit running")
"""

def exit_decorator(
exit_fn: Callable[[Context], None]
) -> Callable[[Context], None]:
"""Register the exit fn with the ServerApp object."""
warn_preview_feature("ClientApp-register-exit-function")

# Register provided function with the ClientApp object
self._exit = exit_fn

# Return provided function unmodified
return exit_fn

return exit_decorator


class LoadClientAppError(Exception):
"""Error when trying to load `ClientApp`."""
Expand Down Expand Up @@ -283,3 +339,17 @@ def _registration_error(fn_name: str) -> ValueError:
>>> )
""",
)


@contextmanager
def manage_client_app(app: ClientApp, context: Context) -> Iterator[None]:
"""Manage the lifecycle of a ClientApp."""
# pylint: disable=protected-access
try:
if app._enter is not None:
app._enter(context)
yield
finally:
if app._exit is not None:
app._exit(context)
# pylint: enable=protected-access
5 changes: 3 additions & 2 deletions src/py/flwr/client/clientapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import grpc

from flwr.cli.install import install_from_fab
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.client_app import ClientApp, LoadClientAppError, manage_client_app
from flwr.common import Context, Message
from flwr.common.args import add_args_flwr_app_common
from flwr.common.config import get_flwr_dir
Expand Down Expand Up @@ -133,7 +133,8 @@ def run_clientapp( # pylint: disable=R0914
)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
with manage_client_app(client_app, context=context):
reply_message = client_app(message=message, context=context)

except Exception as ex: # pylint: disable=broad-exception-caught
# Don't update/change NodeState
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flwr.common.object_ref import load_app

from .driver import Driver
from .server_app import LoadServerAppError, ServerApp
from .server_app import LoadServerAppError, ServerApp, manage_server_app


def run(
Expand Down Expand Up @@ -60,7 +60,8 @@ def _load() -> ServerApp:
server_app = _load()

# Call ServerApp
server_app(driver=driver, context=context)
with manage_server_app(server_app, driver=driver, context=context):
server_app(driver=driver, context=context)

log(DEBUG, "ServerApp finished running.")
return context
Expand Down
70 changes: 69 additions & 1 deletion src/py/flwr/server/server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Flower ServerApp."""


from collections.abc import Iterator
from contextlib import contextmanager
from typing import Callable, Optional

from flwr.common import Context
Expand Down Expand Up @@ -45,7 +47,7 @@ def server_fn(context: Context):
"""


class ServerApp:
class ServerApp: # pylint: disable=too-many-instance-attributes
"""Flower ServerApp.
Examples
Expand Down Expand Up @@ -105,6 +107,8 @@ def __init__(
self._client_manager = client_manager
self._server_fn = server_fn
self._main: Optional[ServerAppCallable] = None
self._enter: Optional[ServerAppCallable] = None
self._exit: Optional[ServerAppCallable] = None

def __call__(self, driver: Driver, context: Context) -> None:
"""Execute `ServerApp`."""
Expand Down Expand Up @@ -177,6 +181,70 @@ def main_decorator(main_fn: ServerAppCallable) -> ServerAppCallable:

return main_decorator

def enter(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
"""Return a decorator that registers the enter fn with the server app.
Examples
--------
>>> app = ServerApp()
>>>
>>> @app.enter()
>>> def enter(driver: Driver, context: Context) -> None:
>>> print("ServerApp enter running")
"""

def enter_decorator(enter_fn: ServerAppCallable) -> ServerAppCallable:
"""Register the enter fn with the ServerApp object."""
warn_preview_feature("ServerApp-register-enter-function")

# Register provided function with the ServerApp object
self._enter = enter_fn

# Return provided function unmodified
return enter_fn

return enter_decorator

def exit(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
"""Return a decorator that registers the exit fn with the server app.
Examples
--------
>>> app = ServerApp()
>>>
>>> @app.exit()
>>> def exit(context: Context) -> None:
>>> print("ServerApp exit running")
"""

def exit_decorator(exit_fn: ServerAppCallable) -> ServerAppCallable:
"""Register the exit fn with the ServerApp object."""
warn_preview_feature("ServerApp-register-exit-function")

# Register provided function with the ServerApp object
self._exit = exit_fn

# Return provided function unmodified
return exit_fn

return exit_decorator


class LoadServerAppError(Exception):
"""Error when trying to load `ServerApp`."""


@contextmanager
def manage_server_app(
app: ServerApp, driver: Driver, context: Context
) -> Iterator[None]:
"""Manage the lifecycle of a ServerApp."""
# pylint: disable=protected-access
try:
if app._enter is not None:
app._enter(driver, context)
yield
finally:
if app._exit is not None:
app._exit(driver, context)
# pylint: enable=protected-access

0 comments on commit 7aa9d31

Please sign in to comment.