Skip to content

Commit

Permalink
feat(framework) Add end-to-end FAB delivery (#3852)
Browse files (browse the repository at this point in the history)
Co-authored-by: Taner Topal <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
3 people authored Aug 16, 2024
1 parent 425af4f commit e366b3a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 19 deletions.
23 changes: 17 additions & 6 deletions src/py/flwr/client/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,6 +26,8 @@
from cryptography.hazmat.primitives.asymmetric import ec
from grpc import RpcError

from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.typing import ClientFnExt
Expand Down Expand Up @@ -339,7 +341,7 @@ def _on_backoff(retry_state: RetryState) -> None:
root_certificates,
authentication_keys,
) as conn:
receive, send, create_node, delete_node, get_run, _ = conn
receive, send, create_node, delete_node, get_run, get_fab = conn

# Register node when connecting the first time
if node_state is None:
Expand Down Expand Up @@ -406,9 +408,16 @@ def _on_backoff(retry_state: RetryState) -> None:
else:
runs[run_id] = Run(run_id, "", "", "", {})

run: Run = runs[run_id]
if get_fab is not None and run.fab_hash:
fab = get_fab(run.fab_hash)
install_from_fab(fab.content, flwr_path, True)
else:
fab = None

# Register context for this run
node_state.register_context(
run_id=run_id, run=runs[run_id], flwr_path=flwr_path
run_id=run_id, run=run, flwr_path=flwr_path
)

# Retrieve context for this run
Expand All @@ -423,10 +432,12 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
run: Run = runs[run_id]
client_app: ClientApp = load_client_app_fn(
run.fab_id, run.fab_version
)
if fab:
fab_id, fab_version = get_fab_metadata(fab.content)
else:
fab_id, fab_version = run.fab_id, run.fab_version

client_app: ClientApp = load_client_app_fn(fab_id, fab_version)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/client/supernode/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,7 +28,7 @@
)

from flwr.common import EventType, event
from flwr.common.config import get_flwr_dir, parse_config_args
from flwr.common.config import parse_config_args
from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
Expand Down Expand Up @@ -73,7 +73,6 @@ def run_supernode() -> None:
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
node_config=parse_config_args([args.node_config]),
flwr_path=get_flwr_dir(args.flwr_dir),
)

# Graceful shutdown
Expand Down
20 changes: 18 additions & 2 deletions src/py/flwr/server/run_serverapp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,6 +21,8 @@
from pathlib import Path
from typing import Optional

from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.common import Context, EventType, RecordSet, event
from flwr.common.config import (
get_flwr_dir,
Expand All @@ -36,6 +38,7 @@
CreateRunRequest,
CreateRunResponse,
)
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611

from .driver import Driver
from .driver.grpc_driver import GrpcDriver
Expand Down Expand Up @@ -87,7 +90,8 @@ def _load() -> ServerApp:
log(DEBUG, "ServerApp finished running.")


def run_server_app() -> None: # pylint: disable=too-many-branches
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
def run_server_app() -> None:
"""Run Flower server app."""
event(EventType.RUN_SERVER_APP_ENTER)

Expand Down Expand Up @@ -164,7 +168,19 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
)
flwr_dir = get_flwr_dir(args.flwr_dir)
run_ = driver.run
app_path = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
if run_.fab_hash:
fab_req = GetFabRequest(hash_str=run_.fab_hash)
# pylint: disable-next=W0212
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
if fab_res.fab.hash_str != run_.fab_hash:
raise ValueError("FAB hashes don't match.")

install_from_fab(fab_res.fab.content, flwr_dir, True)
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
else:
fab_id, fab_version = run_.fab_id, run_.fab_version

app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
config = get_project_config(app_path)
else:
# User provided `app_dir`, but not `--run-id`
Expand Down
17 changes: 8 additions & 9 deletions src/py/flwr/superexec/deployment.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,19 +14,19 @@
# ==============================================================================
"""Deployment engine executor."""

import hashlib
import subprocess
from logging import ERROR, INFO
from pathlib import Path
from typing import Optional

from typing_extensions import override

from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import user_config_to_proto
from flwr.common.typing import UserConfig
from flwr.common.serde import fab_to_proto, user_config_to_proto
from flwr.common.typing import Fab, UserConfig
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
from flwr.proto.driver_pb2_grpc import DriverStub
from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER
Expand Down Expand Up @@ -113,8 +113,7 @@ def _connect(self) -> None:

def _create_run(
self,
fab_id: str,
fab_version: str,
fab: Fab,
override_config: UserConfig,
) -> int:
if self.stub is None:
Expand All @@ -123,8 +122,7 @@ def _create_run(
assert self.stub is not None

req = CreateRunRequest(
fab_id=fab_id,
fab_version=fab_version,
fab=fab_to_proto(fab),
override_config=user_config_to_proto(override_config),
)
res = self.stub.CreateRun(request=req)
Expand All @@ -140,11 +138,12 @@ def start_run(
"""Start run using the Flower Deployment Engine."""
try:
# Install FAB to flwr dir
fab_version, fab_id = get_fab_metadata(fab_file)
install_from_fab(fab_file, None, True)

# Call SuperLink to create run
run_id: int = self._create_run(fab_id, fab_version, override_config)
run_id: int = self._create_run(
Fab(hashlib.sha256(fab_file).hexdigest(), fab_file), override_config
)
log(INFO, "Created run %s", str(run_id))

command = [
Expand Down

0 comments on commit e366b3a

Please sign in to comment.