feat(framework) Add end-to-end FAB delivery #3852

Merged on Aug 16, 2024 (85 commits)
Changes from 79 commits
Commits
2d3398a
feat(framework:skip) Add Flower File Storage interface and disk based…
tanertopal Jun 16, 2024
8f4d7c2
feat(framework:skip) Add Flower File Storage interface and disk based…
tanertopal Jun 16, 2024
b06d6cd
Merge branch 'add_first_ffs_impl' of github.com:adap/flower into add_…
tanertopal Jun 16, 2024
5f57946
Fix type errors
tanertopal Jun 16, 2024
adeeb3c
Fix
tanertopal Jun 16, 2024
bf4ff86
Merge branch 'main' into add_first_ffs_impl
tanertopal Jun 16, 2024
6af4ae5
Merge branch 'main' into add_first_ffs_impl
danieljanes Jun 17, 2024
ea35bfb
feat(framework) Add necessary proto files for FAB delivery
charlesbvll Jul 18, 2024
c4186b5
Update deployment plugin
charlesbvll Jul 18, 2024
bc70b5f
Add necessary code changes
charlesbvll Jul 18, 2024
214ef0e
Fix imports
charlesbvll Jul 18, 2024
dd36e56
Fix mypy errors
charlesbvll Jul 18, 2024
8fcc3d6
Fix imports
charlesbvll Jul 18, 2024
ab76f27
Fix docstring
charlesbvll Jul 19, 2024
a8ab701
Fix errors
charlesbvll Jul 19, 2024
113b739
Fix imports
charlesbvll Jul 19, 2024
d565c87
Sort _all_
charlesbvll Jul 19, 2024
fb0de76
Merge branch 'main' into add_first_ffs_impl
charlesbvll Jul 19, 2024
99eb7e0
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 19, 2024
6837800
Merge branch 'add_first_ffs_impl' into fab-delivery-protos
charlesbvll Jul 19, 2024
6210c6c
Add FFS factory and use it
charlesbvll Jul 19, 2024
62a1433
Fix test
charlesbvll Jul 19, 2024
c2874ff
Fix imports
charlesbvll Jul 19, 2024
7e1ae8c
Fix imports
charlesbvll Jul 19, 2024
6e76e3b
Fix test
charlesbvll Jul 19, 2024
2ed64e0
Lots of changes
charlesbvll Jul 21, 2024
1a77241
Merge branch 'main' of https://github.com/adap/flower into fab-delive…
charlesbvll Jul 22, 2024
e61274b
Fix protos
charlesbvll Jul 22, 2024
755cf2b
Merge branch 'main' of https://github.com/adap/flower into fab-delive…
charlesbvll Jul 25, 2024
57c7020
Keep Run compatible
charlesbvll Jul 25, 2024
a5729c1
Compile protos
charlesbvll Jul 25, 2024
583a62b
Use both fab_id/version and fab_hash
charlesbvll Jul 25, 2024
8a62541
Fix imports
charlesbvll Jul 26, 2024
d99a5d1
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
0025418
Fix docstring
charlesbvll Jul 26, 2024
06d8fb7
Fixes
charlesbvll Jul 26, 2024
c96964f
Fix vce_api
charlesbvll Jul 26, 2024
7d8ade9
Fix mypy errors
charlesbvll Jul 26, 2024
22d3259
Remove mypy error
charlesbvll Jul 26, 2024
042328d
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
4ee42ed
Fix pylint errors
charlesbvll Jul 26, 2024
23f5c8e
Fix ruff error
charlesbvll Jul 26, 2024
945cfd9
Create base dir
charlesbvll Jul 26, 2024
9e76d5f
Add hash to Fab in superexec
charlesbvll Jul 26, 2024
fb75921
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
708ede6
Add debug strings
charlesbvll Jul 26, 2024
7c209dc
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
dadd9cc
Fix compatibility
charlesbvll Jul 26, 2024
f0477c2
Fix condition
charlesbvll Jul 26, 2024
0830a9f
Add debug strings
charlesbvll Jul 26, 2024
eb318f5
Change debug statements
charlesbvll Jul 26, 2024
d0cbf02
Fix debug
charlesbvll Jul 26, 2024
a541326
Fix condition
charlesbvll Jul 26, 2024
bb0c1f4
Try adding fab_hash
charlesbvll Jul 26, 2024
8294542
Try fixing message_handler
charlesbvll Jul 26, 2024
1332e3c
Try a fix
charlesbvll Jul 26, 2024
3f93955
Probable fix
charlesbvll Jul 26, 2024
6ece597
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
8066b4f
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
958db52
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 27, 2024
b433ddc
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 7, 2024
e739321
Fix mypy errors
charlesbvll Aug 7, 2024
2f41ad5
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 13, 2024
1c7300a
Remove unused code
charlesbvll Aug 13, 2024
f47a564
Fix serverapp
charlesbvll Aug 13, 2024
4899836
Fix serde functions
charlesbvll Aug 13, 2024
f4253cf
Remove pylint error
charlesbvll Aug 13, 2024
f567fcd
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
78d7aa5
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
bd54fcb
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
d2b763b
Update src/py/flwr/cli/run/run.py
charlesbvll Aug 15, 2024
9802bab
Fix config
charlesbvll Aug 15, 2024
c032f8c
Fix test
charlesbvll Aug 15, 2024
72b4a9d
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
2a09183
Fix tests
charlesbvll Aug 15, 2024
e385c79
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
ef02893
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
3b2a45f
Remove unused import
charlesbvll Aug 16, 2024
37fdb16
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
a54e139
Revert node_state changes
charlesbvll Aug 16, 2024
e24c03e
Fix order
charlesbvll Aug 16, 2024
6e7fc71
Remove unused optional
charlesbvll Aug 16, 2024
ca708ce
Use declared variable
charlesbvll Aug 16, 2024
a7bd191
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
a5fd135
remove unused import
charlesbvll Aug 16, 2024
27 changes: 17 additions & 10 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import time
from dataclasses import dataclass
from logging import ERROR, INFO, WARN
- from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union

import grpc
from cryptography.hazmat.primitives.asymmetric import ec
from grpc import RpcError

+ from flwr.cli.config_utils import get_fab_config, get_fab_metadata
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.typing import ClientFnExt
Expand Down Expand Up @@ -201,7 +201,6 @@ def start_client_internal(
] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
- flwr_path: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.

Expand Down Expand Up @@ -247,8 +246,6 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
- flwr_path: Optional[Path] (default: None)
-     The fully resolved path containing installed Flower Apps.
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -339,7 +336,7 @@ def _on_backoff(retry_state: RetryState) -> None:
root_certificates,
authentication_keys,
) as conn:
- receive, send, create_node, delete_node, get_run, _ = conn
+ receive, send, create_node, delete_node, get_run, get_fab = conn

# Register node when connecting the first time
if node_state is None:
Expand Down Expand Up @@ -406,9 +403,17 @@ def _on_backoff(retry_state: RetryState) -> None:
else:
runs[run_id] = Run(run_id, "", "", "", {})

+ run: Run = runs[run_id]
+ if get_fab is not None and run.fab_hash:
+     fab = get_fab(run.fab_hash)
+ else:
+     fab = None
+
# Register context for this run
node_state.register_context(
- run_id=run_id, run=runs[run_id], flwr_path=flwr_path
+ run_id=run_id,
+ default_config=get_fab_config(fab.content) if fab else {},
+ run=runs[run_id],
)

# Retrieve context for this run
Expand All @@ -423,10 +428,12 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
- run: Run = runs[run_id]
- client_app: ClientApp = load_client_app_fn(
-     run.fab_id, run.fab_version
- )
+ if fab:
+     fab_id, fab_version = get_fab_metadata(fab.content)
+ else:
+     fab_id, fab_version = run.fab_id, run.fab_version
+
+ client_app: ClientApp = load_client_app_fn(fab_id, fab_version)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
Expand Down
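The client-side change in this file amounts to a two-way fallback: when the run carries a `fab_hash` and the connection exposes a `get_fab` callable, the FAB is fetched and its metadata decides which app is loaded; otherwise the legacy `fab_id`/`fab_version` pair on the run is used. A minimal sketch of that decision follows. The `Fab` and `Run` dataclasses are simplified stand-ins, and `resolve_app_identity` and the `read_metadata` callback are hypothetical names used only for illustration, not part of Flower.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass
class Fab:
    """Stand-in for a FAB: a content hash plus the raw bundle bytes."""

    hash_str: str
    content: bytes


@dataclass
class Run:
    """Stand-in holding only the run fields this sketch needs."""

    run_id: int
    fab_id: str
    fab_version: str
    fab_hash: str


def resolve_app_identity(
    run: Run,
    get_fab: Optional[Callable[[str], Fab]],
    read_metadata: Callable[[bytes], Tuple[str, str]],
) -> Tuple[str, str]:
    """Return (fab_id, fab_version), preferring a FAB delivered by hash."""
    # Fetch the FAB only if the transport supports it and the run has a hash
    fab = get_fab(run.fab_hash) if get_fab is not None and run.fab_hash else None
    if fab:
        # Hypothetical callback that extracts (fab_id, fab_version) from bytes
        return read_metadata(fab.content)
    # Legacy path: identifiers stored directly on the run
    return run.fab_id, run.fab_version
```

Keeping the legacy branch means runs created without a FAB hash still resolve to an installed app.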
8 changes: 5 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Dict, Optional

from flwr.common import Context, RecordSet
- from flwr.common.config import get_fused_config, get_fused_config_from_dir
+ from flwr.common.config import fuse_dicts, get_fused_config_from_dir
from flwr.common.typing import Run, UserConfig


Expand All @@ -47,8 +47,8 @@ def __init__(
def register_context(
self,
run_id: int,
+ default_config: UserConfig,
run: Optional[Run] = None,
- flwr_path: Optional[Path] = None,
app_dir: Optional[str] = None,
) -> None:
"""Register new run context for this node."""
Expand All @@ -66,7 +66,9 @@ def register_context(
raise ValueError("The specified `app_dir` must be a directory.")
else:
# Load from .fab
- initial_run_config = get_fused_config(run, flwr_path) if run else {}
+ initial_run_config = (
+     fuse_dicts(default_config, run.override_config) if run else {}
+ )
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
Expand Down
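With this change `register_context` no longer reads the config from an installed app directory. It receives the FAB's default config and fuses it with the run's `override_config`. The sketch below shows one plausible merge, assuming override values win on key collisions. The diff does not show how the real `fuse_dicts` treats override keys that are absent from the defaults, so keeping them here is an assumption, and `merge_config` is a hypothetical name.

```python
from typing import Dict, Union

ConfigValue = Union[bool, float, int, str]
Config = Dict[str, ConfigValue]


def merge_config(defaults: Config, overrides: Config) -> Config:
    """Return a new config where override values replace defaults."""
    merged = dict(defaults)  # copy so the caller's defaults stay untouched
    merged.update(overrides)  # assumption: unknown override keys are kept
    return merged
```

For example, `merge_config({"lr": 0.1, "epochs": 1}, {"epochs": 5})` yields a config with `lr` left at its default and `epochs` set to 5.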
2 changes: 1 addition & 1 deletion src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_multirun_in_node_state() -> None:
run_id = task.run_id

# Register
- node_state.register_context(run_id=run_id)
+ node_state.register_context(run_id=run_id, default_config={})

# Get run state
context = node_state.retrieve_context(run_id=run_id)
Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)

from flwr.common import EventType, event
- from flwr.common.config import get_flwr_dir, parse_config_args
+ from flwr.common.config import parse_config_args
from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
Expand Down Expand Up @@ -73,7 +73,6 @@ def run_supernode() -> None:
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
node_config=parse_config_args([args.node_config]),
- flwr_path=get_flwr_dir(args.flwr_dir),
)

# Graceful shutdown
Expand Down
37 changes: 26 additions & 11 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from pathlib import Path
from typing import Optional

+ from flwr.cli.config_utils import get_fab_config
from flwr.common import Context, EventType, RecordSet, event
from flwr.common.config import (
+ fuse_dicts,
get_flwr_dir,
get_fused_config_from_dir,
get_metadata_from_config,
Expand All @@ -36,6 +38,7 @@
CreateRunRequest,
CreateRunResponse,
)
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611

from .driver import Driver
from .driver.grpc_driver import GrpcDriver
Expand All @@ -46,7 +49,7 @@

def run(
driver: Driver,
- server_app_dir: str,
+ server_app_dir: Optional[str],
server_app_run_config: UserConfig,
server_app_attr: Optional[str] = None,
loaded_server_app: Optional[ServerApp] = None,
Expand Down Expand Up @@ -87,7 +90,8 @@ def _load() -> ServerApp:
log(DEBUG, "ServerApp finished running.")


- def run_server_app() -> None: # pylint: disable=too-many-branches
+ # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
+ def run_server_app() -> None:
"""Run Flower server app."""
event(EventType.RUN_SERVER_APP_ENTER)

Expand Down Expand Up @@ -162,10 +166,23 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
driver_service_address=args.superlink,
root_certificates=root_certificates,
)
- flwr_dir = get_flwr_dir(args.flwr_dir)
run_ = driver.run
- app_path = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
- config = get_project_config(app_path)
+ if run_.fab_hash:
+     fab_req = GetFabRequest(hash_str=run_.fab_hash)
+     # pylint: disable-next=W0212
+     fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
+     if fab_res.fab.hash_str != run_.fab_hash:
+         raise ValueError("FAB hashes don't match.")
+
+     config = get_fab_config(fab_res.fab.content)
+     server_app_run_config = fuse_dicts(config, run_.override_config)
+ else:
+     flwr_dir = get_flwr_dir(args.flwr_dir)
+     app_path = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
+     config = get_project_config(app_path)
+     server_app_run_config = get_fused_config_from_dir(
+         Path(app_path), driver.run.override_config
+     )
else:
# User provided `app_dir`, but not `--run-id`
# Create run if run_id is not provided
Expand All @@ -184,13 +201,11 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
# Overwrite driver._run_id
driver._run_id = res.run_id # pylint: disable=W0212

- # Obtain server app reference and the run config
- server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
- server_app_run_config = get_fused_config_from_dir(
-     Path(app_path), driver.run.override_config
- )
+ server_app_run_config = get_fused_config_from_dir(
+     Path(app_path), driver.run.override_config
+ )

- log(DEBUG, "Flower will load ServerApp `%s` in %s", server_app_attr, app_path)
+ server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]

log(
DEBUG,
Expand Down
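The server-side check in this file compares the `hash_str` returned with the FAB against the hash stored on the run. A stricter variant, sketched below, recomputes the digest from the received bytes instead of trusting the reported hash. It assumes the hash is a SHA-256 hex digest of the FAB content, which matches how `deployment.py` in this PR builds it. `verify_fab_content` is a hypothetical helper, not a Flower function.

```python
import hashlib


def verify_fab_content(content: bytes, expected_hash: str) -> None:
    """Raise ValueError if the bytes do not hash to the expected digest."""
    # Assumption: FAB hashes are SHA-256 hex digests of the raw bundle bytes
    actual_hash = hashlib.sha256(content).hexdigest()
    if actual_hash != expected_hash:
        raise ValueError("FAB hashes don't match.")
```

Recomputing the digest also catches a bundle that was corrupted in transit while its reported hash stayed intact.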
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:

# Construct NodeState and retrieve context
node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)})
- node_state.register_context(run_id=run_id)
+ node_state.register_context(run_id=run_id, default_config={})
context = node_state.retrieve_context(run_id=run_id)

# Expected output
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _register_node_states(

# Pre-register Context objects
node_states[node_id].register_context(
- run_id=run.run_id, run=run, app_dir=app_dir
+ run_id=run.run_id, default_config={}, run=run, app_dir=app_dir
)

return node_states
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
run_id = message.metadata.run_id

# Register state
- self.proxy_state.register_context(run_id=run_id)
+ self.proxy_state.register_context(run_id=run_id, default_config={})

# Retrieve context
context = self.proxy_state.retrieve_context(run_id=run_id)
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None:
shuffle(proxies)
for prox in proxies:
# Register state
- prox.proxy_state.register_context(run_id=run_id)
+ prox.proxy_state.register_context(run_id=run_id, default_config={})
# Retrieve state
state = prox.proxy_state.retrieve_context(run_id=run_id)

Expand Down Expand Up @@ -229,7 +229,7 @@ def _load_app() -> ClientApp:
),
)
# register and retrieve context
- node_states[node_id].register_context(run_id=run_id)
+ node_states[node_id].register_context(run_id=run_id, default_config={})
context = node_states[node_id].retrieve_context(run_id=run_id)
partition_id_str = str(context.node_config[PARTITION_ID_KEY])
pool.submit_client_job(
Expand Down
17 changes: 8 additions & 9 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
# ==============================================================================
"""Deployment engine executor."""

+ import hashlib
import subprocess
from logging import ERROR, INFO
from pathlib import Path
from typing import Optional

from typing_extensions import override

- from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.common.grpc import create_channel
from flwr.common.logger import log
- from flwr.common.serde import user_config_to_proto
- from flwr.common.typing import UserConfig
+ from flwr.common.serde import fab_to_proto, user_config_to_proto
+ from flwr.common.typing import Fab, UserConfig
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
from flwr.proto.driver_pb2_grpc import DriverStub
from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER
Expand Down Expand Up @@ -113,8 +113,7 @@ def _connect(self) -> None:

def _create_run(
self,
- fab_id: str,
- fab_version: str,
+ fab: Fab,
override_config: UserConfig,
) -> int:
if self.stub is None:
Expand All @@ -123,8 +122,7 @@ def _create_run(
assert self.stub is not None

req = CreateRunRequest(
- fab_id=fab_id,
- fab_version=fab_version,
+ fab=fab_to_proto(fab),
override_config=user_config_to_proto(override_config),
)
res = self.stub.CreateRun(request=req)
Expand All @@ -140,11 +138,12 @@ def start_run(
"""Start run using the Flower Deployment Engine."""
try:
# Install FAB to flwr dir
- fab_version, fab_id = get_fab_metadata(fab_file)
install_from_fab(fab_file, None, True)

# Call SuperLink to create run
- run_id: int = self._create_run(fab_id, fab_version, override_config)
+ run_id: int = self._create_run(
+     Fab(hashlib.sha256(fab_file).hexdigest(), fab_file), override_config
+ )
log(INFO, "Created run %s", str(run_id))

command = [
Expand Down