From e366b3a8aa060b855e401f5902639e32dbe517a5 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Fri, 16 Aug 2024 18:04:53 +0200 Subject: [PATCH] feat(framework) Add end-to-end FAB delivery (#3852) Co-authored-by: Taner Topal Co-authored-by: Daniel J. Beutel --- src/py/flwr/client/app.py | 23 +++++++++++++++++------ src/py/flwr/client/supernode/app.py | 3 +-- src/py/flwr/server/run_serverapp.py | 20 ++++++++++++++++++-- src/py/flwr/superexec/deployment.py | 17 ++++++++--------- 4 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ef56c45939b4..bc42075aa9e5 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -26,6 +26,8 @@ from cryptography.hazmat.primitives.asymmetric import ec from grpc import RpcError +from flwr.cli.config_utils import get_fab_metadata +from flwr.cli.install import install_from_fab from flwr.client.client import Client from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.client.typing import ClientFnExt @@ -339,7 +341,7 @@ def _on_backoff(retry_state: RetryState) -> None: root_certificates, authentication_keys, ) as conn: - receive, send, create_node, delete_node, get_run, _ = conn + receive, send, create_node, delete_node, get_run, get_fab = conn # Register node when connecting the first time if node_state is None: @@ -406,9 +408,16 @@ def _on_backoff(retry_state: RetryState) -> None: else: runs[run_id] = Run(run_id, "", "", "", {}) + run: Run = runs[run_id] + if get_fab is not None and run.fab_hash: + fab = get_fab(run.fab_hash) + install_from_fab(fab.content, flwr_path, True) + else: + fab = None + # Register context for this run node_state.register_context( - run_id=run_id, run=runs[run_id], flwr_path=flwr_path + run_id=run_id, run=run, flwr_path=flwr_path ) # Retrieve context for this run @@ -423,10 +432,12 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - run: Run = runs[run_id] - client_app: ClientApp = load_client_app_fn( - run.fab_id, run.fab_version - ) + if fab: + fab_id, fab_version = get_fab_metadata(fab.content) + else: + fab_id, fab_version = run.fab_id, run.fab_version + + client_app: ClientApp = load_client_app_fn(fab_id, fab_version) # Execute ClientApp reply_message = client_app(message=message, context=context) diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index d0928f8201fa..4370d7d1219d 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -28,7 +28,7 @@ ) from flwr.common import EventType, event -from flwr.common.config import get_flwr_dir, parse_config_args +from flwr.common.config import parse_config_args from flwr.common.constant import ( TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, @@ -73,7 +73,6 @@ def run_supernode() -> None: max_retries=args.max_retries, max_wait_time=args.max_wait_time, node_config=parse_config_args([args.node_config]), - flwr_path=get_flwr_dir(args.flwr_dir), ) # Graceful shutdown diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 76eae30330d9..8f67c917c8ed 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -21,6 +21,8 @@ from pathlib import Path from typing import Optional +from flwr.cli.config_utils import get_fab_metadata +from flwr.cli.install import install_from_fab from flwr.common import Context, EventType, RecordSet, event from flwr.common.config import ( get_flwr_dir, @@ -36,6 +38,7 @@ CreateRunRequest, CreateRunResponse, ) +from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from .driver import Driver from .driver.grpc_driver import GrpcDriver @@ -87,7 +90,8 @@ def _load() -> ServerApp: log(DEBUG, "ServerApp finished running.") -def run_server_app() -> None: # pylint: disable=too-many-branches +# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals +def run_server_app() -> None: """Run Flower server app.""" event(EventType.RUN_SERVER_APP_ENTER) @@ -164,7 +168,19 @@ def run_server_app() -> None: # pylint: disable=too-many-branches ) flwr_dir = get_flwr_dir(args.flwr_dir) run_ = driver.run - app_path = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir)) + if run_.fab_hash: + fab_req = GetFabRequest(hash_str=run_.fab_hash) + # pylint: disable-next=W0212 + fab_res: GetFabResponse = driver._stub.GetFab(fab_req) + if fab_res.fab.hash_str != run_.fab_hash: + raise ValueError("FAB hashes don't match.") + + install_from_fab(fab_res.fab.content, flwr_dir, True) + fab_id, fab_version = get_fab_metadata(fab_res.fab.content) + else: + fab_id, fab_version = run_.fab_id, run_.fab_version + + app_path = str(get_project_dir(fab_id, fab_version, flwr_dir)) config = get_project_config(app_path) else: # User provided `app_dir`, but not `--run-id` diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index fd09b512a52c..2354e047a1ec 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -14,6 +14,7 @@ # ============================================================================== """Deployment engine executor.""" +import hashlib import subprocess from logging import ERROR, INFO from pathlib import Path @@ -21,12 +22,11 @@ from typing_extensions import override -from flwr.cli.config_utils import get_fab_metadata from flwr.cli.install import install_from_fab from flwr.common.grpc import create_channel from flwr.common.logger import log -from flwr.common.serde import user_config_to_proto -from flwr.common.typing import UserConfig +from flwr.common.serde import fab_to_proto, user_config_to_proto +from flwr.common.typing import Fab, UserConfig from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 from flwr.proto.driver_pb2_grpc import DriverStub from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER @@ -113,8 +113,7 @@ def _connect(self) -> None: def _create_run( self, - fab_id: str, - fab_version: str, + fab: Fab, override_config: UserConfig, ) -> int: if self.stub is None: @@ -123,8 +122,7 @@ def _create_run( assert self.stub is not None req = CreateRunRequest( - fab_id=fab_id, - fab_version=fab_version, + fab=fab_to_proto(fab), override_config=user_config_to_proto(override_config), ) res = self.stub.CreateRun(request=req) @@ -140,11 +138,12 @@ def start_run( """Start run using the Flower Deployment Engine.""" try: # Install FAB to flwr dir - fab_version, fab_id = get_fab_metadata(fab_file) install_from_fab(fab_file, None, True) # Call SuperLink to create run - run_id: int = self._create_run(fab_id, fab_version, override_config) + run_id: int = self._create_run( + Fab(hashlib.sha256(fab_file).hexdigest(), fab_file), override_config + ) log(INFO, "Created run %s", str(run_id)) command = [