Skip to content

Commit

Permalink
feat(framework) Add override config to Run
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 7, 2024
1 parent d3aec92 commit 8587e73
Show file tree
Hide file tree
Showing 15 changed files with 120 additions and 55 deletions.
25 changes: 19 additions & 6 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
from dataclasses import dataclass
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union

from cryptography.hazmat.primitives.asymmetric import ec
Expand All @@ -29,6 +30,7 @@
from flwr.client.typing import ClientFnExt
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
from flwr.common.address import parse_address
from flwr.common.config import get_fused_config
from flwr.common.constant import (
MISSING_EXTRA_REST,
TRANSPORT_TYPE_GRPC_ADAPTER,
Expand All @@ -41,6 +43,7 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Run

from .grpc_adapter_client.connection import grpc_adapter
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -192,6 +195,7 @@ def _start_client_internal(
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
partition_id: Optional[int] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand Down Expand Up @@ -235,9 +239,16 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
partitioni_id: Optional[int] (default: None)
partition_id: Optional[int] (default: None)
The data partition index associated with this node. Better suited for
prototyping purposes.
flwr_dir: Optional[Path] (default: None)
The path containing installed Flower Apps.
By default, this value is equal to:
- `$FLWR_HOME/` if `$FLWR_HOME` is defined
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
- `$HOME/.flwr/` in all other cases
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -315,8 +326,7 @@ def _on_backoff(retry_state: RetryState) -> None:
)

node_state = NodeState(partition_id=partition_id)
# run_id -> (fab_id, fab_version)
run_info: Dict[int, Tuple[str, str]] = {}
run_info: Dict[int, Run] = {}

while not app_state_tracker.interrupt:
sleep_duration: int = 0
Expand Down Expand Up @@ -371,13 +381,14 @@ def _on_backoff(retry_state: RetryState) -> None:
run_info[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
run_info[run_id] = ("", "")
run_info[run_id] = Run(run_id, "", "", {})

# Register context for this run
node_state.register_context(run_id=run_id)

# Retrieve context for this run
context = node_state.retrieve_context(run_id=run_id)
context.run_config = get_fused_config(run_info[run_id], flwr_dir)

# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
Expand All @@ -388,7 +399,9 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn(*run_info[run_id])
client_app: ClientApp = load_client_app_fn(
run_info[run_id].fab_id, run_info[run_id].fab_version
)

# Execute ClientApp
reply_message = client_app(message=message, context=context)
Expand Down Expand Up @@ -573,7 +586,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
],
],
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_adapter_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flwr.common.logger import log
from flwr.common.message import Message
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.typing import Run


@contextmanager
Expand All @@ -45,7 +46,7 @@ def grpc_adapter( # pylint: disable=R0913
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Primitives for request/response-based interaction with a server via GrpcAdapter.
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.typing import Run
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Reason,
Expand Down Expand Up @@ -73,7 +74,7 @@ def grpc_connection( # pylint: disable=R0913, R0915
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Establish a gRPC connection to a gRPC server.
Expand Down
12 changes: 9 additions & 3 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
Expand Down Expand Up @@ -80,7 +81,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Primitives for request/response-based interaction with a server.
Expand Down Expand Up @@ -266,7 +267,7 @@ def send(message: Message) -> None:
# Cleanup
metadata = None

def get_run(run_id: int) -> Tuple[str, str]:
def get_run(run_id: int) -> Run:
# Call FleetAPI
get_run_request = GetRunRequest(run_id=run_id)
get_run_response: GetRunResponse = retry_invoker.invoke(
Expand All @@ -275,7 +276,12 @@ def get_run(run_id: int) -> Tuple[str, str]:
)

# Return the Run (run_id, fab_id, fab_version, and override_config)
return get_run_response.run.fab_id, get_run_response.run.fab_version
return Run(
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
dict(get_run_response.run.override_config.items()),
)

try:
# Yield methods
Expand Down
14 changes: 10 additions & 4 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.common.typing import Run
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
CreateNodeResponse,
Expand Down Expand Up @@ -91,7 +92,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], None]],
Optional[Callable[[int], Tuple[str, str]]],
Optional[Callable[[int], Run]],
]
]:
"""Primitives for request/response-based interaction with a server.
Expand Down Expand Up @@ -344,16 +345,21 @@ def send(message: Message) -> None:
res.results, # pylint: disable=no-member
)

def get_run(run_id: int) -> Tuple[str, str]:
def get_run(run_id: int) -> Run:
# Construct the request
req = GetRunRequest(run_id=run_id)

# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return "", ""
return Run(run_id, "", "", {})

return res.run.fab_id, res.run.fab_version
return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
dict(res.run.override_config.items()),
)

try:
# Yield methods
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
override_config=dict(res.run.override_config.items()),
)

return self.stub, self._run.run_id
Expand Down
10 changes: 7 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def setUp(self) -> None:
for _ in range(self.num_nodes)
]
self.state.get_run.return_value = Run(
run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
override_config={"test_key": "test_value"},
)
state_factory = MagicMock(state=lambda: self.state)
self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory)
Expand All @@ -98,6 +101,7 @@ def test_get_run(self) -> None:
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.override_config["test_key"], "test_value")

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
Expand Down Expand Up @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
# Prepare
state = StateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", ""), MagicMock(state=lambda: state)
state.create_run("", "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
Expand All @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", ""), state_factory)
self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)

Expand Down
6 changes: 5 additions & 1 deletion src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def CreateRun(
"""Create run ID."""
log(DEBUG, "DriverServicer.CreateRun")
state: State = self.state_factory.state()
run_id = state.create_run(request.fab_id, request.fab_version)
run_id = state.create_run(
request.fab_id,
request.fab_version,
dict(request.override_config.items()),
)
return CreateRunResponse(run_id=run_id)

def PushTaskIns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "")
run_id = self.state.create_run("", "", {})
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._client_private_key, self._server_public_key
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "")
run_id = self.state.create_run("", "", {})
request = GetRunRequest(run_id=run_id)
client_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(client_private_key, self._server_public_key)
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def register_messages_into_state(
) -> Dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0")
state.run_ids[run_id] = Run(
run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={}
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes
Expand Down
12 changes: 10 additions & 2 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,23 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `client_public_key`."""
return self.public_key_to_node_id.get(client_public_key)

def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for `fab_id`, `fab_version`, and `override_config`."""
# Sample a random int64 as run_id
with self.lock:
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run_id=run_id, fab_id=fab_id, fab_version=fab_version
run_id=run_id,
fab_id=fab_id,
fab_version=fab_version,
override_config=override_config,
)
return run_id
log(ERROR, "Unexpected run creation failure.")
Expand Down
23 changes: 18 additions & 5 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import re
import sqlite3
import time
from ast import literal_eval
from logging import DEBUG, ERROR
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
from uuid import UUID, uuid4
Expand Down Expand Up @@ -63,7 +64,8 @@
CREATE TABLE IF NOT EXISTS run(
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT
fab_version TEXT,
overrides TEXT
);
"""

Expand Down Expand Up @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
return node_id
return None

def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for `fab_id`, `fab_version`, and `override_config`."""
# Sample a random int64 as run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
Expand All @@ -622,8 +629,11 @@ def create_run(self, fab_id: str, fab_version: str) -> int:
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
# If run_id does not exist
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
self.query(query, (run_id, fab_id, fab_version))
query = (
"INSERT INTO run (run_id, fab_id, fab_version, overrides)"
"VALUES (?, ?, ?, ?);"
)
self.query(query, (run_id, fab_id, fab_version, str(override_config)))
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -687,7 +697,10 @@ def get_run(self, run_id: int) -> Optional[Run]:
try:
row = self.query(query, (run_id,))[0]
return Run(
run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
run_id=run_id,
fab_id=row["fab_id"],
fab_version=row["fab_version"],
override_config=literal_eval(row["overrides"]),
)
except sqlite3.IntegrityError:
log(ERROR, "`run_id` does not exist.")
Expand Down
Loading

0 comments on commit 8587e73

Please sign in to comment.