Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add end-to-end FAB delivery #3852

Merged
merged 85 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 70 commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
2d3398a
feat(framework:skip) Add Flower File Storage interface and disk based…
tanertopal Jun 16, 2024
8f4d7c2
feat(framework:skip) Add Flower File Storage interface and disk based…
tanertopal Jun 16, 2024
b06d6cd
Merge branch 'add_first_ffs_impl' of github.com:adap/flower into add_…
tanertopal Jun 16, 2024
5f57946
Fix type errors
tanertopal Jun 16, 2024
adeeb3c
Fix
tanertopal Jun 16, 2024
bf4ff86
Merge branch 'main' into add_first_ffs_impl
tanertopal Jun 16, 2024
6af4ae5
Merge branch 'main' into add_first_ffs_impl
danieljanes Jun 17, 2024
ea35bfb
feat(framework) Add necessary proto files for FAB delivery
charlesbvll Jul 18, 2024
c4186b5
Update deployment plugin
charlesbvll Jul 18, 2024
bc70b5f
Add necessary code changes
charlesbvll Jul 18, 2024
214ef0e
Fix imports
charlesbvll Jul 18, 2024
dd36e56
Fix mypy errors
charlesbvll Jul 18, 2024
8fcc3d6
Fix imports
charlesbvll Jul 18, 2024
ab76f27
Fix docstring
charlesbvll Jul 19, 2024
a8ab701
Fix errors
charlesbvll Jul 19, 2024
113b739
Fix imports
charlesbvll Jul 19, 2024
d565c87
Sort _all_
charlesbvll Jul 19, 2024
fb0de76
Merge branch 'main' into add_first_ffs_impl
charlesbvll Jul 19, 2024
99eb7e0
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 19, 2024
6837800
Merge branch 'add_first_ffs_impl' into fab-delivery-protos
charlesbvll Jul 19, 2024
6210c6c
Add FFS factory and use it
charlesbvll Jul 19, 2024
62a1433
Fix test
charlesbvll Jul 19, 2024
c2874ff
Fix imports
charlesbvll Jul 19, 2024
7e1ae8c
Fix imports
charlesbvll Jul 19, 2024
6e76e3b
Fix test
charlesbvll Jul 19, 2024
2ed64e0
Lots of changes
charlesbvll Jul 21, 2024
1a77241
Merge branch 'main' of https://github.com/adap/flower into fab-delive…
charlesbvll Jul 22, 2024
e61274b
Fix protos
charlesbvll Jul 22, 2024
755cf2b
Merge branch 'main' of https://github.com/adap/flower into fab-delive…
charlesbvll Jul 25, 2024
57c7020
Keep Run compatible
charlesbvll Jul 25, 2024
a5729c1
Compile protos
charlesbvll Jul 25, 2024
583a62b
Use both fab_id/version and fab_hash
charlesbvll Jul 25, 2024
8a62541
Fix imports
charlesbvll Jul 26, 2024
d99a5d1
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
0025418
Fix docstring
charlesbvll Jul 26, 2024
06d8fb7
Fixes
charlesbvll Jul 26, 2024
c96964f
Fix vce_api
charlesbvll Jul 26, 2024
7d8ade9
Fix mypy errors
charlesbvll Jul 26, 2024
22d3259
Remove mypy error
charlesbvll Jul 26, 2024
042328d
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
4ee42ed
Fix pylint errors
charlesbvll Jul 26, 2024
23f5c8e
Fix ruff error
charlesbvll Jul 26, 2024
945cfd9
Create base dir
charlesbvll Jul 26, 2024
9e76d5f
Add hash to Fab in superexec
charlesbvll Jul 26, 2024
fb75921
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
708ede6
Add debug strings
charlesbvll Jul 26, 2024
7c209dc
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
dadd9cc
Fix compatibility
charlesbvll Jul 26, 2024
f0477c2
Fix condition
charlesbvll Jul 26, 2024
0830a9f
Add debug strings
charlesbvll Jul 26, 2024
eb318f5
Change debug statements
charlesbvll Jul 26, 2024
d0cbf02
Fix debug
charlesbvll Jul 26, 2024
a541326
Fix condition
charlesbvll Jul 26, 2024
bb0c1f4
Try adding fab_hash
charlesbvll Jul 26, 2024
8294542
Try fixing message_handler
charlesbvll Jul 26, 2024
1332e3c
Try a fix
charlesbvll Jul 26, 2024
3f93955
Probable fix
charlesbvll Jul 26, 2024
6ece597
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
8066b4f
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 26, 2024
958db52
Merge branch 'main' into fab-delivery-protos
charlesbvll Jul 27, 2024
b433ddc
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 7, 2024
e739321
Fix mypy errors
charlesbvll Aug 7, 2024
2f41ad5
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 13, 2024
1c7300a
Remove unused code
charlesbvll Aug 13, 2024
f47a564
Fix serverapp
charlesbvll Aug 13, 2024
4899836
Fix serde functions
charlesbvll Aug 13, 2024
f4253cf
Remove pylint error
charlesbvll Aug 13, 2024
f567fcd
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
78d7aa5
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
bd54fcb
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
d2b763b
Update src/py/flwr/cli/run/run.py
charlesbvll Aug 15, 2024
9802bab
Fix config
charlesbvll Aug 15, 2024
c032f8c
Fix test
charlesbvll Aug 15, 2024
72b4a9d
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 15, 2024
2a09183
Fix tests
charlesbvll Aug 15, 2024
e385c79
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
ef02893
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
3b2a45f
Remove unused import
charlesbvll Aug 16, 2024
37fdb16
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
a54e139
Revert node_state changes
charlesbvll Aug 16, 2024
e24c03e
Fix order
charlesbvll Aug 16, 2024
6e7fc71
Remove unused optional
charlesbvll Aug 16, 2024
ca708ce
Use decalred variable
charlesbvll Aug 16, 2024
a7bd191
Merge branch 'main' into fab-delivery-protos
charlesbvll Aug 16, 2024
a5fd135
remove unused import
charlesbvll Aug 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def on_channel_state_change(channel_connectivity: str) -> None:
content = fab_path.read_bytes()
fab = Fab(hashlib.sha256(content).hexdigest(), content)

fab = Fab("", Path(fab_path).read_bytes())

req = StartRunRequest(
fab=fab_to_proto(fab),
override_config=user_config_to_proto(
Expand Down
29 changes: 18 additions & 11 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import time
from dataclasses import dataclass
from logging import ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union

from cryptography.hazmat.primitives.asymmetric import ec
from grpc import RpcError

from flwr.cli.config_utils import get_fab_config, get_fab_metadata
from flwr.client.client import Client
from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.typing import ClientFnExt
Expand Down Expand Up @@ -195,7 +195,6 @@ def _start_client_internal(
] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
flwr_path: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.

Expand Down Expand Up @@ -241,8 +240,6 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
flwr_path: Optional[Path] (default: None)
The fully resolved path containing installed Flower Apps.
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -333,7 +330,7 @@ def _on_backoff(retry_state: RetryState) -> None:
root_certificates,
authentication_keys,
) as conn:
receive, send, create_node, delete_node, get_run, _ = conn
receive, send, create_node, delete_node, get_run, get_fab = conn

# Register node when connecting the first time
if node_state is None:
Expand Down Expand Up @@ -398,11 +395,19 @@ def _on_backoff(retry_state: RetryState) -> None:
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
runs[run_id] = Run(run_id, "", "", {})
runs[run_id] = Run(run_id, "", "", "", {})

run: Run = runs[run_id]
if get_fab is not None and run.fab_hash:
fab = get_fab(run.fab_hash)
else:
fab = None

# Register context for this run
node_state.register_context(
run_id=run_id, run=runs[run_id], flwr_path=flwr_path
run_id=run_id,
default_config=get_fab_config(fab.content) if fab else {},
run=runs[run_id],
)

# Retrieve context for this run
Expand All @@ -417,10 +422,12 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
run: Run = runs[run_id]
client_app: ClientApp = load_client_app_fn(
run.fab_id, run.fab_version
)
if fab:
fab_id, fab_version = get_fab_metadata(fab.content)
else:
fab_id, fab_version = run.fab_id, run.fab_version

client_app: ClientApp = load_client_app_fn(fab_id, fab_version)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
Expand Down
10 changes: 9 additions & 1 deletion src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
user_config_from_proto,
)
from flwr.common.typing import Fab, Run
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
Expand Down Expand Up @@ -286,12 +287,19 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
get_run_response.run.fab_hash,
user_config_from_proto(get_run_response.run.override_config),
)

def get_fab(fab_hash: str) -> Fab:
# Call FleetAPI
raise NotImplementedError
get_fab_request = GetFabRequest(hash_str=fab_hash)
get_fab_response: GetFabResponse = retry_invoker.invoke(
stub.GetFab,
request=get_fab_request,
)

return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)

try:
# Yield methods
Expand Down
8 changes: 5 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Dict, Optional

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config, get_fused_config_from_dir
from flwr.common.config import fuse_dicts, get_fused_config_from_dir
from flwr.common.typing import Run, UserConfig


Expand All @@ -47,8 +47,8 @@ def __init__(
def register_context(
self,
run_id: int,
default_config: UserConfig,
run: Optional[Run] = None,
flwr_path: Optional[Path] = None,
app_dir: Optional[str] = None,
) -> None:
"""Register new run context for this node."""
Expand All @@ -66,7 +66,9 @@ def register_context(
raise ValueError("The specified `app_dir` must be a directory.")
else:
# Load from .fab
initial_run_config = get_fused_config(run, flwr_path) if run else {}
initial_run_config = (
fuse_dicts(default_config, run.override_config) if run else {}
)
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_multirun_in_node_state() -> None:
run_id = task.run_id

# Register
node_state.register_context(run_id=run_id)
node_state.register_context(run_id=run_id, default_config={})

# Get run state
context = node_state.retrieve_context(run_id=run_id)
Expand Down
19 changes: 16 additions & 3 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
user_config_from_proto,
)
from flwr.common.typing import Fab, Run
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
CreateNodeResponse,
Expand Down Expand Up @@ -74,6 +75,7 @@
PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
PATH_PING: str = "api/v0/fleet/ping"
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"

T = TypeVar("T", bound=GrpcMessage)

Expand Down Expand Up @@ -358,18 +360,29 @@ def get_run(run_id: int) -> Run:
# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return Run(run_id, "", "", {})
return Run(run_id, "", "", "", {})

return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
res.run.fab_hash,
user_config_from_proto(res.run.override_config),
)

def get_fab(fab_hash: str) -> Fab:
# Call FleetAPI
raise NotImplementedError
# Construct the request
req = GetFabRequest(hash_str=fab_hash)

# Send the request
res = _request(req, GetFabResponse, PATH_GET_FAB)
if res is None:
return Fab("", b"")

return Fab(
res.fab.hash_str,
res.fab.content,
)

try:
# Yield methods
Expand Down
1 change: 0 additions & 1 deletion src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def run_supernode() -> None:
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
node_config=parse_config_args([args.node_config]),
flwr_path=get_flwr_dir(args.flwr_dir),
)

# Graceful shutdown
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from flwr.cli.config_utils import validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.typing import Run, UserConfig, UserConfigValue
from flwr.common.typing import UserConfig, UserConfigValue


def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
fab_id=run.fab_id,
fab_version=run.fab_version,
override_config=user_config_to_proto(run.override_config),
fab_hash="",
fab_hash=run.fab_hash,
)
return proto

Expand All @@ -862,6 +862,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
run_id=run_proto.run_id,
fab_id=run_proto.fab_id,
fab_version=run_proto.fab_version,
fab_hash=run_proto.fab_hash,
override_config=user_config_from_proto(run_proto.override_config),
)
return run
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class Run:
run_id: int
fab_id: str
fab_version: str
fab_hash: str
override_config: UserConfig


Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.config import get_flwr_dir
from flwr.common.constant import (
MISSING_EXTRA_REST,
TRANSPORT_TYPE_GRPC_ADAPTER,
Expand All @@ -57,6 +58,7 @@
from .server_config import ServerConfig
from .strategy import Strategy
from .superlink.driver.driver_grpc import run_driver_api_grpc
from .superlink.ffs.ffs_factory import FfsFactory
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
from .superlink.fleet.grpc_bidi.grpc_server import (
generic_create_grpc_server,
Expand All @@ -72,6 +74,7 @@
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"

DATABASE = ":flwr-in-memory-state:"
BASE_DIR = get_flwr_dir() / "ffs"


def start_server( # pylint: disable=too-many-arguments,too-many-locals
Expand Down Expand Up @@ -211,10 +214,14 @@ def run_superlink() -> None:
# Initialize StateFactory
state_factory = StateFactory(args.database)

# Initialize StateFactory
ffs_factory = FfsFactory(args.base_dir)

# Start Driver API
driver_server: grpc.Server = run_driver_api_grpc(
address=driver_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)

Expand Down Expand Up @@ -294,6 +301,7 @@ def run_superlink() -> None:
fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
interceptors=interceptors,
)
Expand Down Expand Up @@ -480,13 +488,15 @@ def _try_obtain_certificates(
def _run_fleet_api_grpc_rere(
address: str,
state_factory: StateFactory,
ffs_factory: FfsFactory,
certificates: Optional[Tuple[bytes, bytes, bytes]],
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
) -> grpc.Server:
"""Run Fleet API (gRPC, request-response)."""
# Create Fleet API gRPC server
fleet_servicer = FleetServicer(
state_factory=state_factory,
ffs_factory=ffs_factory,
)
fleet_add_servicer_to_server_fn = add_FleetServicer_to_server
fleet_grpc_server = generic_create_grpc_server(
Expand Down Expand Up @@ -610,6 +620,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
"Flower will just create a state in memory.",
default=DATABASE,
)
parser.add_argument(
"--base-dir",
help="The base directory to store the objects.",
default=BASE_DIR,
)
parser.add_argument(
"--auth-list-public-keys",
type=str,
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _init_run(self) -> None:
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
fab_hash=res.run.fab_hash,
override_config=user_config_from_proto(res.run.override_config),
)

Expand Down
8 changes: 7 additions & 1 deletion src/py/flwr/server/driver/grpc_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ class TestGrpcDriver(unittest.TestCase):
def setUp(self) -> None:
"""Initialize mock GrpcDriverStub and Driver instance before each test."""
mock_response = Mock(
run=Run(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0")
run=Run(
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
fab_hash="mock/mock",
)
)
self.mock_stub = Mock()
self.mock_channel = Mock()
Expand All @@ -55,6 +60,7 @@ def test_init_grpc_driver(self) -> None:
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.fab_hash, "mock/mock")
self.mock_stub.GetRun.assert_called_once()

def test_get_nodes(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def setUp(self) -> None:
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
fab_hash="mock/mock",
override_config={"test_key": "test_value"},
)
state_factory = MagicMock(state=lambda: self.state)
Expand All @@ -99,6 +100,7 @@ def test_get_run(self) -> None:
"""Test the InMemoryDriver starting with run_id."""
# Assert
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_hash, "mock/mock")
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.override_config["test_key"], "test_value")
Expand Down Expand Up @@ -227,7 +229,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
# Prepare
state = StateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", "", {}), MagicMock(state=lambda: state)
state.create_run(None, None, "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
Expand All @@ -253,7 +255,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory)
self.driver = InMemoryDriver(state.create_run("", "", None, {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)

Expand Down
Loading
Loading